Repository: Lightning-AI/litgpt Branch: main Commit: 162ad9bee317 Files: 233 Total size: 1.8 MB Directory structure: gitextract_ctr2cg_x/ ├── .devcontainer/ │ ├── Dockerfile │ └── devcontainer.json ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE/ │ │ ├── ask-a-question.md │ │ ├── bug-report.yaml │ │ └── feature-request.md │ ├── dependabot.yml │ └── workflows/ │ ├── check-links.yml │ ├── cpu-tests.yml │ ├── mkdocs-deploy.yml │ └── publish-pkg.yml ├── .gitignore ├── .lightning/ │ └── workflows/ │ └── tests.yaml ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── README.md ├── config_hub/ │ ├── finetune/ │ │ ├── README.md │ │ ├── falcon-7b/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── gemma-2b/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── gemma-7b/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── gemma2-2b/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── gemma2-9b/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── llama-2-7b/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── llama-3-8b/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── llama-3.1-8b/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── llama-3.2-1B/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── llama-3.2-3B/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── mistral-7b/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── mistral-7b-v0.2/ │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── phi-2/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── phi-3/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ ├── stablelm-base-alpha-3b/ │ │ │ ├── full.yaml │ │ │ ├── lora.yaml │ │ │ └── qlora.yaml │ │ └── tiny-llama/ │ │ ├── full.yaml │ │ ├── lora.yaml │ │ └── qlora.yaml │ └── pretrain/ │ ├── debug.yaml │ ├── microllama.yaml │ ├── tinyllama.yaml │ └── tinystories.yaml ├── extensions/ │ ├── thunder/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── 
pretrain.py │ │ ├── strategies/ │ │ │ ├── __init__.py │ │ │ ├── thunder_ddp.py │ │ │ └── thunder_fsdp.py │ │ └── unsloth/ │ │ ├── __init__.py │ │ ├── executor.py │ │ └── kernels/ │ │ ├── __init__.py │ │ ├── cross_entropy_loss.py │ │ ├── rope_embedding.py │ │ ├── swiglu.py │ │ └── utils.py │ └── xla/ │ ├── README.md │ ├── __init__ │ ├── finetune/ │ │ ├── __init__ │ │ └── adapter.py │ ├── generate/ │ │ ├── __init__ │ │ ├── adapter.py │ │ └── base.py │ ├── scripts/ │ │ ├── __init__ │ │ └── prepare_alpaca.py │ └── utils.py ├── litgpt/ │ ├── __init__.py │ ├── __main__.py │ ├── adapter.py │ ├── adapter_v2.py │ ├── api.py │ ├── args.py │ ├── chat/ │ │ ├── __init__.py │ │ └── base.py │ ├── config.py │ ├── constants.py │ ├── data/ │ │ ├── __init__.py │ │ ├── alpaca.py │ │ ├── alpaca_2k.py │ │ ├── alpaca_gpt4.py │ │ ├── base.py │ │ ├── deita.py │ │ ├── flan.py │ │ ├── json_data.py │ │ ├── lima.py │ │ ├── lit_data.py │ │ ├── longform.py │ │ ├── microllama.py │ │ ├── openwebtext.py │ │ ├── prepare_slimpajama.py │ │ ├── prepare_starcoder.py │ │ ├── text_files.py │ │ ├── tinyllama.py │ │ └── tinystories.py │ ├── deploy/ │ │ ├── __init__.py │ │ └── serve.py │ ├── eval/ │ │ └── evaluate.py │ ├── finetune/ │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── adapter_v2.py │ │ ├── full.py │ │ ├── lora.py │ │ └── lora_legacy.py │ ├── generate/ │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── adapter_v2.py │ │ ├── base.py │ │ ├── full.py │ │ ├── sequentially.py │ │ ├── speculative_decoding.py │ │ └── tp.py │ ├── lora.py │ ├── model.py │ ├── parser_config.py │ ├── pretrain.py │ ├── prompts.py │ ├── scripts/ │ │ ├── __init__.py │ │ ├── convert_hf_checkpoint.py │ │ ├── convert_lit_checkpoint.py │ │ ├── convert_pretrained_checkpoint.py │ │ ├── download.py │ │ └── merge_lora.py │ ├── tokenizer.py │ ├── types.py │ └── utils.py ├── pyproject.toml ├── tests/ │ ├── conftest.py │ ├── convert/ │ │ ├── __init__.py │ │ ├── test_hf_checkpoint.py │ │ ├── test_lit_checkpoint.py │ │ └── 
test_pretrained_checkpoint.py │ ├── data/ │ │ ├── __init__.py │ │ ├── _fixtures/ │ │ │ ├── alpaca.json │ │ │ ├── dolly.json │ │ │ ├── longform_train.json │ │ │ └── longform_val.json │ │ ├── test_alpaca.py │ │ ├── test_base.py │ │ ├── test_deita.py │ │ ├── test_json.py │ │ ├── test_lit_data.py │ │ ├── test_longform.py │ │ ├── test_openwebtext.py │ │ ├── test_textfiles.py │ │ ├── test_tinyllama.py │ │ └── test_tinystories.py │ ├── ext_thunder/ │ │ ├── __init__.py │ │ ├── test_thunder_distributed.py │ │ ├── test_thunder_networks.py │ │ ├── test_thunder_pretrain.py │ │ └── test_unsloth_executor.py │ ├── generate/ │ │ ├── __init__.py │ │ ├── test_adapter.py │ │ ├── test_main.py │ │ ├── test_sequentially.py │ │ ├── test_tp.py │ │ └── utils.py │ ├── test_adapter.py │ ├── test_adapter_v2.py │ ├── test_api.py │ ├── test_args.py │ ├── test_batch.py │ ├── test_chat.py │ ├── test_ci.py │ ├── test_cli.py │ ├── test_config.py │ ├── test_config_hub.py │ ├── test_deepseek_moe.py │ ├── test_distributed.py │ ├── test_evaluate.py │ ├── test_full.py │ ├── test_generate_speculatively.py │ ├── test_lora.py │ ├── test_merge_lora.py │ ├── test_model.py │ ├── test_multihead_latent_attention.py │ ├── test_pretrain.py │ ├── test_prompts.py │ ├── test_readme.py │ ├── test_rope.py │ ├── test_serve.py │ ├── test_tokenizer.py │ ├── test_trainer_support.py │ ├── test_types.py │ ├── test_utils.py │ └── test_yarn.py └── tutorials/ ├── 0_to_litgpt.md ├── convert_hf_checkpoint.md ├── convert_lit_models.md ├── deploy.md ├── developer-docs/ │ ├── README.md │ ├── adding-models.md │ └── python-api.md ├── download_model_weights.md ├── evaluation.md ├── examples/ │ └── ptl-trainer/ │ ├── README.md │ ├── litgpt_ptl_medium.py │ └── litgpt_ptl_small.py ├── finetune.md ├── finetune_adapter.md ├── finetune_full.md ├── finetune_lora.md ├── full_finetune_example.py ├── inference.md ├── mkdocs.yml ├── oom.md ├── prepare_dataset.md ├── pretrain.md ├── pretrain_tinyllama.md ├── python-api.md ├── quantize.md └── 
resource-tables.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: .devcontainer/Dockerfile ================================================ # See here for image contents: https://github.com/devcontainers/images/blob/main/src/python/.devcontainer/Dockerfile # [Choice] Python version (use -bookworm or -bullseye variants on local arm64/Apple Silicon): 3, 3.12, 3.11, 3.10, 3.9, 3.8, 3-bookworm, 3.12-bookworm, 3.11-bookworm, 3.10-bookworm, 3.9-bookworm, 3.8-bookworm, 3-bullseye, 3.12-bullseye, 3.11-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3-buster, 3.12-buster, 3.11-buster, 3.10-buster, 3.9-buster, 3.8-buster ARG VARIANT=3-bookworm FROM mcr.microsoft.com/devcontainers/python:1-${VARIANT} # Temporary: Upgrade python packages due to https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-40897 # They are installed by the base image (python) which does not have the patch. RUN python3 -m pip install --upgrade pip setuptools ================================================ FILE: .devcontainer/devcontainer.json ================================================ // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: // https://github.com/microsoft/vscode-dev-containers/tree/v0.194.0/containers/python-3 { "name": "Python 3 (litgpt)", "build": { "dockerfile": "Dockerfile", "context": "..", "args": { "VARIANT": "3.11-bookworm" } }, "runArgs": [ // Enable GPU passthrough, requires WSL2 on Windows //"--gpus=all", // One of the following options is required for torch multiprocessing //"--ipc=host", //"--shm-size=4gb", ], // Features to add to the dev container. More info: https://containers.dev/features. 
"features": { "ghcr.io/devcontainers/features/git:1": {}, "ghcr.io/devcontainers/features/git-lfs:1": {}, //"ghcr.io/devcontainers/features/nvidia-cuda:1": {}, "ghcr.io/devcontainers-extra/features/actionlint:1": {}, "ghcr.io/devcontainers-extra/features/pre-commit:2": {}, "ghcr.io/dhoeric/features/act:1": {}, "ghcr.io/devcontainers/features/docker-in-docker:2": { "version": "latest", "moby": true } }, // Set *default* container specific settings.json values on container create. "customizations": { "vscode": { "settings": { "editor.tabSize": 4, "editor.renderWhitespace": "all", "editor.formatOnSave": true, "editor.rulers": [120], "files.exclude": { "**/__pycache__": true }, "python.pythonPath": "/usr/local/bin/python", "python.defaultInterpreterPath": "/usr/local/bin/python", "python.languageServer": "Pylance", "python.analysis.autoImportCompletions": true, "python.analysis.completeFunctionParens": true, "python.analysis.autoSearchPaths": true, "python.testing.pytestArgs": ["tests"], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "code-eol.highlightNonDefault": true, "code-eol.highlightExtraWhitespace": true, "autoDocstring.docstringFormat": "google-notypes", "autoDocstring.guessTypes": true, "autoDocstring.generateDocstringOnEnter": true, "autoDocstring.startOnNewLine": true, "telemetry.telemetryLevel": "off", "[python]": { "editor.formatOnSave": true, "editor.defaultFormatter": "charliermarsh.ruff", "editor.codeActionsOnSave": { "source.organizeImports": "always", "source.fixAll": "always" } } }, // Add the IDs of extensions you want installed when the container is created. 
"extensions": [ "ms-python.python", "ms-python.vscode-pylance", "ms-toolsai.jupyter", "GitHub.copilot", "GitHub.copilot-chat", "github.vscode-github-actions", "SanjulaGanepola.github-local-actions", "charliermarsh.ruff", "esbenp.prettier-vscode", "ms-vscode.test-adapter-converter", "njqdev.vscode-python-typehint", "KevinRose.vsc-python-indent", "medo64.render-crlf", "shardulm94.trailing-spaces", "nhoizey.gremlins", "wayou.vscode-todo-highlight", "Gruntfuggly.todo-tree", "njpwerner.autodocstring", "rodolphebarbanneau.python-docstring-highlighter", "mechatroner.rainbow-csv", "uctakeoff.vscode-counter", "bierner.github-markdown-preview", "yahyabatulu.vscode-markdown-alert", "ms-vscode-remote.vscode-remote-extensionpack", "ms-azuretools.vscode-docker", "redhat.vscode-yaml" ] } }, // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. "postCreateCommand": "pre-commit install && pip install '.[extra,compiler,test]' -U", // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. "remoteUser": "vscode" } ================================================ FILE: .github/CODEOWNERS ================================================ * @lantiga @t-vi @lianakoleva @KaelanDt @k223kim @andyland /README.md @williamfalcon @lantiga @lianakoleva ================================================ FILE: .github/ISSUE_TEMPLATE/ask-a-question.md ================================================ --- name: Ask a Question about: Ask and answer questions related to LitGPT title: '' labels: question --- Please describe your question here.
================================================ FILE: .github/ISSUE_TEMPLATE/bug-report.yaml ================================================ name: Bug Report description: Report errors related to LitGPT title: "Description" labels: bug body: - type: markdown attributes: value: | Thank you for taking the time to report an issue. Please fill out the details below to help us resolve it. - type: textarea id: bug_description attributes: label: Bug description description: A description of the issue. placeholder: | Please provide a description of what the bug or issue is. validations: required: true - type: input attributes: label: Reproduced in studio description: > Create a new Lightning Studio with code that reproduces the issue and share the link. Also include all the relevant files and data required to reproduce the issue. In case the code does not crash, please add assert statements that show the actual and expected output. A simple guide on how to create such a studio can be found [here](https://www.youtube.com/watch?v=YcW-2Zt_bFg&ab_channel=LightningAI). placeholder: https://lightning.ai/... validations: required: false - type: dropdown id: operating_system attributes: label: What operating system are you using? description: If applicable, please select the operating system where you experienced this issue. options: - "Unknown" - "macOS" - "Linux" - "Windows" validations: required: true - type: textarea id: version attributes: label: LitGPT Version description: | Please provide details about your LitGPT version by running the following code in your terminal: ``` pip show litgpt | grep Version: ``` validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/feature-request.md ================================================ --- name: Suggest a Feature about: Propose a new feature or enhancement title: '' labels: enhancement --- Please describe the feature or enhancement along with the intended use case.
================================================ FILE: .github/dependabot.yml ================================================ # Basic dependabot.yml file with # minimum configuration for two package managers version: 2 updates: # Enable version updates for python - package-ecosystem: "pip" # Look for dependency manifests (e.g. `pyproject.toml`) in the `root` directory directory: "/" # Check for updates once a month schedule: interval: "monthly" # Labels on pull requests for version updates only labels: - "dependencies" pull-request-branch-name: # Separate sections of the branch name with a hyphen # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` separator: "-" # Allow up to 3 open pull requests for pip dependencies open-pull-requests-limit: 3 # Enable version updates for GitHub Actions - package-ecosystem: "github-actions" directory: "/" # Check for updates once a week schedule: interval: "weekly" # Labels on pull requests for version updates only labels: - "CI / actions" pull-request-branch-name: # Separate sections of the branch name with a hyphen # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` separator: "-" # Allow at most 1 open pull request for GitHub Actions (all updates are grouped below) open-pull-requests-limit: 1 groups: GHA-updates: patterns: - "*" ================================================ FILE: .github/workflows/check-links.yml ================================================ name: Check hyperlinks on: push: branches: - main pull_request: branches: - main jobs: test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip pip install "mistune<3.1" # a newer version is incompatible with nbconvert pip install pytest pytest-check-links - name: Check links run: | pytest --check-links README.md --check-links-ignore "http*" pytest --check-links tutorials --check-links-ignore "http*" ================================================
FILE: .github/workflows/cpu-tests.yml ================================================ name: CPU tests on: push: branches: [main] pull_request_target: branches: [main] types: [opened, reopened, ready_for_review, labeled, synchronize] pull_request: {} # todo workflow_dispatch: {} # lock down all permissions by default permissions: contents: read # needed to check out code checks: write # needed for test results pull-requests: read # needed for PR metadata actions: read # needed to use actions security-events: none statuses: write # needed to update commit status concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }} defaults: run: shell: bash env: HF_HOME: .cache-HF # Define HF_HOME for caching TRANSFORMERS_CACHE: .cache-HF/transformers DATASETS_CACHE: .cache-HF/datasets HF_DATASETS_CACHE: .cache-HF/datasets TORCH_URL: "https://download.pytorch.org/whl/cpu/" jobs: testing-imports: runs-on: ${{ matrix.os }} if: github.event_name != 'pull_request_target' strategy: fail-fast: false matrix: os: ["ubuntu-22.04", "ubuntu-24.04", "macOS-14", "windows-2022"] python-version: ["3.10"] timeout-minutes: 10 steps: - name: Checkout generic uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install minimal dependencies run: | pip install . 
-U --extra-index-url="${TORCH_URL}" pip list - name: Testing package imports # make sure all modules are still importable with only the minimal dependencies available run: | modules=$( find litgpt -type f -name "*.py" | \ sed 's/\.py$//' | sed 's/\//./g' | \ sed 's/.__init__//g' | xargs -I {} echo "import {};" ) echo "$modules" python -c "$modules" pytester: # Route PRs based on contributor type to avoid duplicate runs: # - Collaborators: use pull_request (tests workflow changes from PR) # - External forks: use pull_request_target (uses trusted workflow from main) # - Always run for push to main and workflow_dispatch if: | (github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) || (github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) || (github.event_name != 'pull_request' && github.event_name != 'pull_request_target') runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: ["ubuntu-22.04"] python-version: ["3.10", "3.11", "3.12", "3.13"] requires: ["latest"] include: - { os: "ubuntu-22.04", python-version: "3.10", requires: "oldest" } - { os: "windows-2022", python-version: "3.10", requires: "latest" } - { os: "macOS-14", python-version: "3.10", requires: "latest" } timeout-minutes: 35 steps: - name: Checkout generic uses: actions/checkout@v6 if: github.event_name != 'pull_request_target' - name: Checkout for `pull_request_target` uses: actions/checkout@v6 if: github.event_name == 'pull_request_target' with: ref: ${{ github.event.pull_request.head.sha }} - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} cache-dependency-path: pyproject.toml cache: "pip" # Add caching for HF models and tokenizers - name: HF cache uses: actions/cache@v5 continue-on-error: true with: path: .cache-HF key: hf-cache_${{ runner.os }}-py${{ matrix.python-version }} restore-keys: | hf-cache_${{ 
runner.os }}-py${{ matrix.python-version }} hf-cache_${{ runner.os }}- hf-cache_ - name: Set min. dependencies if: matrix.requires == 'oldest' run: | pip install 'lightning-utilities[cli]>=0.15.1' python -m lightning_utilities.cli requirements set-oldest --req_files=pyproject.toml - name: Install dependencies run: | pip install '.[extra,compiler,test]' -U --upgrade-strategy eager --extra-index-url="${TORCH_URL}" pip list - name: Run tests env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: pytest -v litgpt/ tests/ --timeout=180 --durations=100 - name: Show cache run: | pip install -q py-tree python -m py_tree -d 1 .cache-HF testing-guardian: runs-on: ubuntu-latest needs: [pytester, testing-imports] if: | (github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) || (github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) steps: - run: echo "${{ needs.pytester.result }}" - name: failing... if: needs.pytester.result == 'failure' run: exit 1 - name: cancelled or skipped... 
if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) timeout-minutes: 1 run: sleep 90 ================================================ FILE: .github/workflows/mkdocs-deploy.yml ================================================ name: Deploy MkDocs on: push: branches: [main] permissions: contents: write jobs: deploy: runs-on: ubuntu-24.04 steps: # Step 1: Checkout the repository - uses: actions/checkout@v6 # Step 2: Set up Python - uses: actions/setup-python@v6 with: python-version: "3.x" cache: "pip" # Step 3: Install MkDocs and dependencies - run: pip install mkdocs mkdocs-material mkdocs-pagetree-plugin # Step 4: Deploy to GitHub Pages - run: | mkdir -p gh-pages/docs cp -r tutorials/* gh-pages/docs cd gh-pages mv docs/mkdocs.yml mkdocs.yml echo "{{ pagetree }}" > docs/index.md mkdocs gh-deploy --force ================================================ FILE: .github/workflows/publish-pkg.yml ================================================ # To create a release, create a tag and push it to GitHub: #git tag -a "v0.0.1-beta" -m "beta version testing" #git push --tags # https://dev.to/iamtekson/publish-package-to-pypi-and-release-new-version-using-github-actions-108k name: Publish LitGPT to PyPI on: push: tags: - "v*" jobs: build-n-publish: name: Build and publish to PyPI runs-on: ubuntu-latest environment: name: pypi url: https://pypi.org/p/litgpt permissions: id-token: write steps: - name: Checkout source uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: python-version: "3.x" cache: "pip" - name: Build source and wheel distributions run: | python -m pip install --upgrade build twine pip install importlib_metadata==7.2.1 python -m build twine check --strict dist/* - name: Publish distribution to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .gitignore 
================================================ .ipynb_checkpoints/ __pycache__ .idea .DS_Store *.egg-info build dist .venv .venv/ .vscode uv.lock # data data datasets !litgpt/data !tests/data checkpoints out wandb events.out.tfevents* # test artifacts from tests/test_readme.py **/custom_finetuning_dataset.json client.py **/custom_texts/ ================================================ FILE: .lightning/workflows/tests.yaml ================================================ trigger: push: branches: ["main"] pull_request: branches: ["main"] image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.8.1-cudnn-fe1.15.0-py3.12-pt_2.8.0-dev" machine: "L4_X_2" interruptible: "true" timeout: "45" # minutes parametrize: matrix: dependency: ["", "compiler"] include: [] exclude: [] env: SKIP_WITH_CI: "1" # skip single tests with CI NCCL_DEBUG: "INFO" CUBLAS_WORKSPACE_CONFIG: ":4096:8" NCCL_IGNORE_DISABLED_P2P: "1" TORCH_VERSION: "2.8.0" RUN_ONLY_CUDA_TESTS: "1" # run CUDA tests only run: | whereis nvidia nvidia-smi python --version pip --version pip list set -ex echo "Install uv and create virtual environment" curl -LsSf https://astral.sh/uv/install.sh | sh [ -f "$HOME/.local/bin/env" ] && . "$HOME/.local/bin/env" export PATH="$HOME/.local/bin:$PATH" uv venv .venv --system-site-packages . 
.venv/bin/activate hash -r uv pip install -q '.[extra,test]' "torch==${TORCH_VERSION}" cffi -U if [ "${dependency}" == "compiler" ]; then uv pip uninstall torchvision torchaudio uv pip install -q '.[compiler,extra,test]' "torch==${TORCH_VERSION}" python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'" python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'" fi uv pip list python -c "import torch ; gpus = torch.cuda.device_count() ; assert gpus >= 2, f'GPU: {gpus}'" python -c "from torch import __version__ as ver ; assert str(ver).split('+')[0] == '${TORCH_VERSION}', f'PyTorch: installed {ver} but expected ${TORCH_VERSION}'" pytest -v --durations=100 wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh PL_RUN_STANDALONE_TESTS=1 bash run_standalone_tests.sh "tests" if [ "${dependency}" == "compiler" ]; then uv pip uninstall lightning-thunder transformers # install thunder from source, so that, thunder.tests will be available uv pip install -U "lightning-thunder[test] @ git+https://github.com/Lightning-AI/lightning-thunder.git" "torch==${TORCH_VERSION}" # Pin transformers to match thunder's test_networks.py requirements # See: https://github.com/Lightning-AI/lightning-thunder/blob/main/requirements/test.txt # Get transformers version from thunder requirements TRANSFORMERS_VERSION=$(curl -fsSL https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/main/requirements/test.txt \ | grep '^transformers==' \ | cut -d'=' -f3 \ | cut -d'#' -f1 \ | xargs) if [ -z "${TRANSFORMERS_VERSION}" ]; then echo "Error: Could not determine transformers version from lightning-thunder requirements" exit 1 fi uv pip install transformers==${TRANSFORMERS_VERSION} # without env var, it filters out all tests RUN_ONLY_CUDA_TESTS=0 pytest tests/ext_thunder/test_thunder_networks.py -v fi 
================================================ FILE: .pre-commit-config.yaml ================================================ # Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. default_language_version: python: python3 ci: autofix_prs: true autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" autoupdate_schedule: quarterly # submodules: true repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace exclude: README.md - id: check-yaml - id: check-toml #- id: check-docstring-first #- id: check-executables-have-shebangs - id: check-case-conflict - id: check-added-large-files args: ["--maxkb=250", "--enforce-all"] - id: detect-private-key - repo: https://github.com/codespell-project/codespell rev: v2.4.1 hooks: - id: codespell additional_dependencies: [tomli] args: ["--write-changes"] exclude: pyproject.toml #- repo: https://github.com/crate-ci/typos # rev: dictgen-v0.3.1 # hooks: # - id: typos # args: [] # empty to do not write fixes # exclude: pyproject.toml #- repo: https://github.com/executablebooks/mdformat # rev: 0.7.21 # hooks: # - id: mdformat # args: ["--number"] # additional_dependencies: # - mdformat-gfm # - mdformat-black # - mdformat_frontmatter - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 hooks: - id: prettier files: \.(json|yml|yaml|toml) # https://prettier.io/docs/en/options.html#print-width args: ["--print-width=140"] - repo: 
https://github.com/astral-sh/ruff-pre-commit rev: v0.14.10 hooks: - id: ruff args: ["--fix"] - id: ruff-format - id: ruff - repo: https://github.com/tox-dev/pyproject-fmt rev: v2.11.1 hooks: - id: pyproject-fmt additional_dependencies: [tox] - repo: https://github.com/abravalheri/validate-pyproject rev: v0.24.1 hooks: - id: validate-pyproject ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, you can cite it as shown below." title: "LitGPT" abstract: "20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale." date-released: 2023-03-22 authors: - name: "The Lightning AI team" license: "Apache-2.0" url: "https://github.com/Lightning-AI/litgpt" ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [2023] Lightning AI Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
# ⚡ LitGPT **20+ high-performance LLMs with recipes to pretrain, finetune, and deploy at scale.**
✅ From scratch implementations      ✅ No abstractions         ✅ Beginner friendly
   ✅ Flash attention                   ✅ FSDP                    ✅ LoRA, QLoRA, Adapter
✅ Reduce GPU memory (fp4/8/16/32)   ✅ 1-1000+ GPUs/TPUs       ✅ 20+ LLMs         
--- ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/litgpt) ![cpu-tests](https://github.com/Lightning-AI/litgpt/actions/workflows/cpu-tests.yml/badge.svg) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/litgpt/blob/main/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439)](https://discord.gg/VptPCZkGNa)

Quick start • Models • Finetune • Deploy • All workflows • Features • Recipes (YAML) • Lightning AI • Tutorials

  Get started  
# Looking for GPUs? Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) - purpose-built for PyTorch and PyTorch Lightning. - [GPUs](https://lightning.ai/pricing?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) from $0.19. - [Clusters](https://lightning.ai/clusters?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): frontier-grade training/inference clusters. - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you debug, tune and vibe train. - [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you optimize, and deploy models. - [Notebooks](https://lightning.ai/notebooks?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Persistent GPU workspaces where AI helps you code and analyze. - [Inference](https://lightning.ai/deploy?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Deploy models as inference APIs. # Finetune, pretrain, and inference LLMs Lightning fast ⚡⚡ Every LLM is implemented from scratch with **no abstractions** and **full control**, making them blazing fast, minimal, and performant at enterprise scale. ✅ **Enterprise ready -** Apache 2.0 for unlimited enterprise use.
✅ **Developer friendly -** Easy debugging with no abstraction layers and single file implementations.
✅ **Optimized performance -** Models designed to maximize performance, reduce costs, and speed up training.
✅ **Proven recipes -** Highly-optimized training/finetuning recipes tested at enterprise scale.
  # Quick start Install LitGPT ``` pip install 'litgpt[extra]' ``` Load and use any of the [20+ LLMs](#choose-from-20-llms): ```python from litgpt import LLM llm = LLM.load("microsoft/phi-2") text = llm.generate("Fix the spelling: Every fall, the family goes to the mountains.") print(text) # Corrected Sentence: Every fall, the family goes to the mountains. ```   ✅ Optimized for fast inference
✅ Quantization
✅ Runs on low-memory GPUs
✅ No layers of internal abstractions
✅ Optimized for production scale
Advanced install options Install from source: ```bash git clone https://github.com/Lightning-AI/litgpt cd litgpt # if using uv uv sync --all-extras # if using pip pip install -e ".[extra,compiler,test]" ```
[Explore the full Python API docs](tutorials/python-api.md).   --- # Choose from 20+ LLMs Every model is written from scratch to maximize performance and remove layers of abstraction: | Model | Model size | Author | Reference | |----|----|----|----| | Llama 3, 3.1, 3.2, 3.3 | 1B, 3B, 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) | | Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | ... | ... | ... | ... |
See full list of 20+ LLMs   #### All models | Model | Model size | Author | Reference | |----|----|----|----| | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | | Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) | | Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al.
2023](https://arxiv.org/abs/2307.09288) | | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) | | Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) | | Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) | | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | | Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) | | Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) | | Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743) | | Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233) | | Phi 4 Reasoning | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) | | Phi 4 Reasoning Plus | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) | | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al.
2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/) | | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) | | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://huggingface.co/stabilityai/stablelm-zephyr-3b) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | **Tip**: You can list all available models by running the `litgpt download list` command.
  --- # Workflows

Finetune • Pretrain • Continued pretraining • Evaluate • Deploy • Test

  Use the command line interface to run advanced workflows such as pretraining or finetuning on your own data. ## All workflows After installing LitGPT, select the model and workflow to run (finetune, pretrain, evaluate, deploy, etc...): ```bash # litgpt [action] [model] litgpt serve meta-llama/Llama-3.2-3B-Instruct litgpt finetune meta-llama/Llama-3.2-3B-Instruct litgpt pretrain meta-llama/Llama-3.2-3B-Instruct litgpt chat meta-llama/Llama-3.2-3B-Instruct litgpt evaluate meta-llama/Llama-3.2-3B-Instruct ```   ---- ## Finetune an LLM
Run on Studios
  Finetuning is the process of taking a pretrained AI model and further training it on a smaller, specialized dataset tailored to a specific task or application.   ```bash # 0) setup your dataset curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json # 1) Finetune a model (auto downloads weights) litgpt finetune microsoft/phi-2 \ --data JSON \ --data.json_path my_custom_dataset.json \ --data.val_split_fraction 0.1 \ --out_dir out/custom-model # 2) Test the model litgpt chat out/custom-model/final # 3) Deploy the model litgpt serve out/custom-model/final ``` [Read the full finetuning docs](tutorials/finetune.md)   ---- ## Deploy an LLM
Deploy on Studios
  Deploy a pretrained or finetuned LLM to use it in real-world applications. Deploying automatically sets up a web server that can be accessed by a website or app. ```bash # deploy an out-of-the-box LLM litgpt serve microsoft/phi-2 # deploy your own trained model litgpt serve path/to/microsoft/phi-2/checkpoint ```
Show code to query server:   Test the server in a separate terminal and integrate the model API into your AI product: ```python # 3) Use the server (in a separate Python session) import requests, json response = requests.post( "http://127.0.0.1:8000/predict", json={"prompt": "Fix typos in the following sentence: Example input"} ) print(response.json()["output"]) ```
[Read the full deploy docs](tutorials/deploy.md).   ---- ## Evaluate an LLM Evaluate an LLM to test its performance on various tasks to see how well it understands and generates text. Simply put, we can evaluate things like how well it would do in college-level chemistry, coding, etc. (MMLU, TruthfulQA, etc.) ```bash litgpt evaluate microsoft/phi-2 --tasks 'truthfulqa_mc2,mmlu' ``` [Read the full evaluation docs](tutorials/evaluation.md).   ---- ## Test an LLM
Run on Studios
  Test how well the model works via an interactive chat. Use the `chat` command to chat, extract embeddings, etc... Here's an example showing how to use the Phi-2 LLM: ```bash litgpt chat microsoft/phi-2 >> Prompt: What do Llamas eat? ```
Full code:   ```bash # 1) List all supported LLMs litgpt download list # 2) Use a model (auto downloads weights) litgpt chat microsoft/phi-2 >> Prompt: What do Llamas eat? ``` Downloading certain models requires an additional access token. You can read more about this in the [download](tutorials/download_model_weights.md#specific-models-and-access-tokens) documentation.
[Read the full chat docs](tutorials/inference.md).   ---- ## Pretrain an LLM
Run on Studios
  Pretraining is the process of teaching an AI model by exposing it to a large amount of data before it is fine-tuned for specific tasks.
Show code:   ```bash mkdir -p custom_texts curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt # 1) Download a tokenizer litgpt download EleutherAI/pythia-160m \ --tokenizer_only True # 2) Pretrain the model litgpt pretrain EleutherAI/pythia-160m \ --tokenizer_dir EleutherAI/pythia-160m \ --data TextFiles \ --data.train_data_path "custom_texts/" \ --train.max_tokens 10_000_000 \ --out_dir out/custom-model # 3) Test the model litgpt chat out/custom-model/final ```
[Read the full pretraining docs](tutorials/pretrain.md)   ---- ## Continue pretraining an LLM
Run on Studios
  Continued pretraining is another way of finetuning that specializes an already pretrained model by training on custom data:
Show code:   ```bash mkdir -p custom_texts curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt # 1) Continue pretraining a model (auto downloads weights) litgpt pretrain EleutherAI/pythia-160m \ --tokenizer_dir EleutherAI/pythia-160m \ --initial_checkpoint_dir EleutherAI/pythia-160m \ --data TextFiles \ --data.train_data_path "custom_texts/" \ --train.max_tokens 10_000_000 \ --out_dir out/custom-model # 2) Test the model litgpt chat out/custom-model/final ```
[Read the full continued pretraining docs](tutorials/pretrain.md#continued-pretraining-on-custom-data)   ---- # State-of-the-art features ✅ State-of-the-art optimizations: Flash Attention v2, multi-GPU support via fully-sharded data parallelism, [optional CPU offloading](tutorials/oom.md#do-sharding-across-multiple-gpus), and [TPU and XLA support](extensions/xla).
✅ [Pretrain](tutorials/pretrain.md), [finetune](tutorials/finetune.md), and [deploy](tutorials/inference.md)
✅ Reduce compute requirements with low-precision settings: FP16, BF16, and FP16/FP32 mixed.
✅ Lower memory requirements with [quantization](tutorials/quantize.md): 4-bit floats, 8-bit integers, and double quantization.
✅ [Configuration files](config_hub) for great out-of-the-box performance.
✅ Parameter-efficient finetuning: [LoRA](tutorials/finetune_lora.md), [QLoRA](tutorials/finetune_lora.md), [Adapter](tutorials/finetune_adapter.md), and [Adapter v2](tutorials/finetune_adapter.md).
✅ [Exporting](tutorials/convert_lit_models.md) to other popular model weight formats.
✅ Many popular datasets for [pretraining](tutorials/pretrain.md) and [finetuning](tutorials/prepare_dataset.md), and [support for custom datasets](tutorials/prepare_dataset.md#preparing-custom-datasets-for-instruction-finetuning).
✅ Readable and easy-to-modify code to experiment with the latest research ideas.
  --- # Training recipes LitGPT comes with validated recipes (YAML configs) to train models under different conditions. We've generated these recipes based on the parameters we found to perform the best for different training conditions. Browse all training recipes [here](config_hub). ### Example ```bash litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml ```
✅ Use configs to customize training Configs let you customize training for all granular parameters like: ```yaml # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama2-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true ... ```
✅ Example: LoRA finetuning config   ```yaml # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama2-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: float, default: 0.0003) learning_rate: 0.0002 # (type: float, default: 0.02) weight_decay: 0.0 # (type: float, default: 0.9) beta1: 0.9 # (type: float, default: 0.95) beta2: 0.95 # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 ```
✅ Override any parameter in the CLI: ```bash litgpt finetune \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \ --lora_r 4 ```
  ---- # Project highlights LitGPT powers many great AI projects, initiatives, challenges and of course enterprises. Please submit a pull request to be considered for a feature.
📊 SAMBA: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling The [Samba](https://github.com/microsoft/Samba) project by researchers at Microsoft is built on top of the LitGPT code base and combines state space models with sliding window attention, which outperforms pure state space models.
🏆 NeurIPS 2023 Large Language Model Efficiency Challenge: 1 LLM + 1 GPU + 1 Day The LitGPT repository was the official starter kit for the [NeurIPS 2023 LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io), which is a competition focused on finetuning an existing non-instruction tuned LLM for 24 hours on a single GPU.
🦙 TinyLlama: An Open-Source Small Language Model LitGPT powered the [TinyLlama project](https://github.com/jzhang38/TinyLlama) and [TinyLlama: An Open-Source Small Language Model](https://arxiv.org/abs/2401.02385) research paper.
🍪 MicroLlama: MicroLlama-300M [MicroLlama](https://github.com/keeeeenw/MicroLlama) is a 300M Llama model pretrained on 50B tokens powered by TinyLlama and LitGPT.
🔬 Pre-training Small Base LMs with Fewer Tokens The research paper ["Pre-training Small Base LMs with Fewer Tokens"](https://arxiv.org/abs/2404.08634), which utilizes LitGPT, develops smaller base language models by inheriting a few transformer blocks from larger models and training on a tiny fraction of the data used by the larger models. It demonstrates that these smaller models can perform comparably to larger models despite using significantly less training data and resources.
  ---- # Community We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment. - [Request a feature](https://github.com/Lightning-AI/litgpt/issues) - [Submit your first contribution](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) - [Join our Discord](https://discord.gg/VptPCZkGNa)   # Tutorials 🚀 [Get started](tutorials/0_to_litgpt.md)
⚡️ [Finetuning, incl. LoRA, QLoRA, and Adapters](tutorials/finetune.md)
🤖 [Pretraining](tutorials/pretrain.md)
💬 [Model evaluation](tutorials/evaluation.md)
📘 [Supported and custom datasets](tutorials/prepare_dataset.md)
🧹 [Quantization](tutorials/quantize.md)
🤯 [Tips for dealing with out-of-memory (OOM) errors](tutorials/oom.md)
🧑🏽‍💻 [Using cloud TPUs](extensions/xla)
  ---- ### Acknowledgments This implementation builds on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) ⚡**. - [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT) - [@EleutherAI](https://github.com/EleutherAI) for [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) and the [Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) - [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA) - [@tridao](https://github.com/tridao) for [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) ### License LitGPT is released under the [Apache 2.0](https://github.com/Lightning-AI/litgpt/blob/main/LICENSE) license. ### Citation If you use LitGPT in your research, please cite the following work: ```bibtex @misc{litgpt-2023, author = {Lightning AI}, title = {LitGPT}, howpublished = {\url{https://github.com/Lightning-AI/litgpt}}, year = {2023}, } ```   ================================================ FILE: config_hub/finetune/README.md ================================================ ## Config files The table below lists the performance you can expect from the provided config files. Note that you can achieve lower memory consumption by lowering the micro batch size as needed. In addition, you can lower the rank (`lora_r`) in the LoRA configuration files and disable LoRA for certain layers (for example, setting `lora_projection` and other LoRA layer-specific parameters to `false`). For more information, see the [Dealing with out-of-memory (OOM) errors](../../tutorials/oom.md) guide on lowering the memory requirements.
The "Cost" column refers to the on-demand compute cost on [Lightning AI Studios where these benchmarks were executed](https://lightning.ai/lightning-ai/studios/automated-benchmarks-for-litgpt). All experiments were conducted using bfloat16 precision on the Alpaca2k dataset. The "Multitask score" refers to [MMLU](https://arxiv.org/abs/2009.03300).   | Config | Model | Epochs | Max seq length | Micro batch size | Machine | Training runtime | Cost | Peak memory | Validation loss | Validation perplexity | Multitask score (MMLU) | | --------------------------------- | ---------------------- | ------ | -------------- | ---------------- | ------- | ---------------- | ---- | ----------- | --------------- | --------------------- | --------------- | | falcon-7b/lora.yaml | falcon-7b | 4 | 512 | 1 | 1xA10G | 24.84 min | $0.7 | 16.69 GB | 0.945 | 2.573 | 26.2% | | falcon-7b/lora.yaml | falcon-7b | 4 | 512 | 1 | 4xA10G | 24.94 min | $2.0 | 16.69 GB | 0.945 | 2.573 | 26.4% | | falcon-7b/qlora.yaml | falcon-7b | 4 | 512 | 1 | 1xA10G | 50.85 min | $1.5 | 9.44 GB | 0.993 | 2.699 | 26.3% | | | | | | | | | | | | | | | gemma-2b/full.yaml | gemma-2b | 1 | 512 | 1 | 4xA10G | 14.06 min | $1.1 | 17.43 GB | 1.021 | 2.777 | 32.4% | | gemma-2b/lora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 9.41 min | $0.3 | 12.62 GB | 0.981 | 2.666 | 34.4% | | gemma-2b/lora.yaml | gemma-2b | 2 | 512 | 2 | 4xA10G | 9.41 min | $0.8 | 12.62 GB | 0.981 | 2.667 | 34.0% | | gemma-2b/qlora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 12.91 min | $0.4 | 11.58 GB | 1.085 | 2.959 | 36.4% | | | | | | | | | | | | | | | gemma-7b/lora.yaml | gemma-7b | 2 | 512 | 1 | 1xA10G | OOM | OOM | OOM | OOM | OOM | | | gemma-7b/lora.yaml | gemma-7b | 2 | 512 | 1 | 4xA10G | OOM | OOM | OOM | OOM | OOM | | | gemma-7b/qlora.yaml | gemma-7b | 2 | 512 | 1 | 1xA10G | 43.58 min | $1.3 | 17.18 GB | 0.973 | 2.646 | 62.45% | | | | | | | | | | | | | | | gemma2-2b/lora.yaml | gemma-2-2b | 2 | 512 | 2 | 1xA10G | 11.96 min | $0.4 | 14.31 GB | 0.951 |
2.589 | 23.84% | | gemma2b/qlora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 16.06 min | $0.5 | 13.52 GB | 0.983 | 2.673 | 24.12% | | | | | | | | | | | | | | | gemma2-9b/lora.yaml | gemma-2-9b | 2 | 512 | 1 | 1xA10G | OOM | OOM | OOM | OOM | OOM | | | gemma2-9b/lora.yaml | gemma-2-9b | 2 | 512 | 1 | 4xA10G | OOM | OOM | OOM | OOM | OOM | | | gemma2-9b/qlora.yaml | gemma-2-9b | 2 | 512 | 1 | 1xA10G | 50.01 min | $4.0 | 20.92 GB | 0.852 | 2.345 | 24.2% | | | | | | | | | | | | | | | llama-2-7b/full.yaml | llama-2-7b | 1 | 512 | 4 | 4xA10G | OOM | OOM | OOM | OOM | OOM | | | llama-2-7b/lora.yaml | llama-2-7b | 4 | 512 | 2 | 1xA10G | 32.82 min | $1.0 | 19.77 GB | 0.802 | 2.230 | 40.3% | | llama-2-7b/lora.yaml | llama-2-7b | 4 | 512 | 2 | 4xA10G | 32.83 min | $2.6 | 19.77 GB | 0.802 | 2.229 | 40.2% | | llama-2-7b/qlora.yaml | llama-2-7b | 4 | 512 | 2 | 1xA10G | 45.67 min | $1.4 | 13.68 GB | 0.814 | 2.258 | 38.6% | | | | | | | | | | | | | | | llama-3-8b/full.yaml | llama-3-8b | 1 | 512 | 4 | 4xA10G | OOM | OOM | OOM | OOM | OOM | | | llama-3-8b/lora.yaml | llama-3-8b | 2 | 512 | 1 | 1xA10G | 14.79 min | $0.4 | 19.73 GB | 0.888 | 2.431 | 62.4% | | llama-3-8b/lora.yaml | llama-3-8b | 2 | 512 | 1 | 4xA10G | 14.88 min | $1.2 | 19.73 GB | 0.889 | 2.432 | 62.5% | | llama-3-8b/qlora.yaml | llama-3-8b | 2 | 512 | 2 | 1xA10G | 22.24 min | $0.7 | 17.41 GB | 0.939 | 2.558 | 62.2% | | | | | | | | | | | | | | | llama-3.1-8b/full.yaml | llama-3.1-8b | 1 | 512 | 4 | 1xA10G | OOM | OOM | OOM | OOM | OOM | OOM | | llama-3.1-8b/lora.yaml | llama-3.1-8b | 2 | 512 | 1 | 1xA10G | 13.36 min | $1.1 | 19.73 GB | 0.878 | 2.406 | xx.xx | | llama-3.1-8b/qlora.yaml | llama-3.1-8b | 2 | 512 | 2 | 1xA10G | 21.81 min | $0.7 | 17.41 GB | 0.928 | 2.529 | xx.xx | | | | | | | | | | | | | | | llama-3.2-1b/full.yaml | llama-3.2-1b | 1 | 512 | 4 | 1xA10G | 2.01 min | $0.1 | 8.70 GB | 1.442 | 4.229 | 38.21% | | llama-3.2-1b/lora.yaml | llama-3.2-1b | 2 | 512 | 1 | 1xA10G | 4.17 min | $0.4 | 4.49 GB | 1.114 | 
3.046 | 36.87% | | llama-3.2-1b/qlora.yaml | llama-3.2-1b | 2 | 512 | 2 | 1xA10G | 6.20 min | $0.6 | 5.53 GB | 1.201 | 3.322 | 36.49% | | | | | | | | | | | | | | | llama-3.2-3b/full.yaml | llama-3.2-3b | 1 | 512 | 4 | 1xA10G | 4.71 min | $0.4 | 16.51 GB | 1.255 | 3.509 | 54.69% | | llama-3.2-3b/lora.yaml | llama-3.2-3b | 2 | 512 | 1 | 1xA10G | 8.31 min | $0.8 | 9.67 GB | 0.973 | 2.647 | 54.77% | | llama-3.2-3b/qlora.yaml | llama-3.2-3b | 2 | 512 | 2 | 1xA10G | 14.89 min | $1.4 | 10.30 GB | 1.031 | 2.804 | 55.08% | | | | | | | | | | | | | | | mistral-7b-v0.2/lora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 1xA10G | 31.00 min | $0.9 | 20.66 GB | 0.801 | 2.228 | 55.7% | | mistral-7b-v0.2/lora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 4xA10G | 31.00 min | $2.5 | 20.66 GB | 0.802 | 2.229 | 55.5% | | mistral-7b-v0.2/qlora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 1xA10G | 44.75 min | $1.3 | 14.29 GB | 0.813 | 2.255 | 56.5% | | | | | | | | | | | | | | | mistral-7b/lora.yaml | mistral-7b | 4 | 512 | 2 | 1xA10G | 31.01 min | $0.9 | 20.66 GB | 0.794 | 2.211 | 57.9% | | mistral-7b/lora.yaml | mistral-7b | 4 | 512 | 2 | 4xA10G | 31.03 min | $2.5 | 20.66 GB | 0.796 | 2.218 | 57.9% | | mistral-7b/qlora.yaml | mistral-7b | 4 | 512 | 2 | 1xA10G | 44.75 min | $1.3 | 14.29 GB | 0.803 | 2.231 | 57.9% | | | | | | | | | | | | | | | phi-2/full.yaml | phi-2 | 1 | 512 | 4 | 4xA10G | 11.87 min | $1.0 | 14.44 GB | 1.305 | 3.688 | 38.4% | | phi-2/lora.yaml | phi-2 | 1 | 512 | 4 | 1xA10G | 3.78 min | $0.1 | 13.98 GB | 0.819 | 2.269 | 53.0% | | phi-2/lora.yaml | phi-2 | 1 | 512 | 4 | 4xA10G | 3.78 min | $0.3 | 13.98 GB | 0.820 | 2.271 | 52.4% | | phi-2/qlora.yaml | phi-2 | 1 | 512 | 4 | 1xA10G | 4.51 min | $0.1 | 14.27 GB | 0.837 | 2.310 | 52.3% | | | | | | | | | | | | | | | phi-3/full.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 6.93 min | $0.2 | 17.01 GB | 0.714 | 2.043 | 69.81% | | phi-3/lora.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 6.46 min | $0.2 | 19.75 GB | 0.707 | 2.028 | 
69.70% | | phi-3/qlora.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 7.47 min | $0.2 | 19.13 GB | 0.729 | 2.074 | 68.96% | | | | | | | | | | | | | | | stablelm-base-alpha-3b/full.yaml | stablelm-base-alpha-3b | 1 | 512 | 1 | 4xA10G | 70.13 min | $5.6 | 21.23 GB | 1.513 | 4.540 | 23.2% | | stablelm-base-alpha-3b/lora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 1xA10G | 13.07 min | $0.4 | 8.58 GB | 1.361 | 3.900 | 25.9% | | stablelm-base-alpha-3b/lora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 4xA10G | 13.16 min | $1.1 | 8.58 GB | 1.362 | 3.906 | 25.9% | | stablelm-base-alpha-3b/qlora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 1xA10G | 25.86 min | $0.8 | 5.24 GB | 1.388 | 4.009 | 26.1% | | | | | | | | | | | | | | | tiny-llama/full.yaml | tiny-llama | 1 | 512 | 4 | 1xA10G | 2.58 min | $0.1 | 14.10 GB | 1.088 | 2.968 | 24.6% | | tiny-llama/full.yaml | tiny-llama | 1 | 512 | 4 | 4xA10G | 2.57 min | $0.2 | 14.10 GB | 1.088 | 2.968 | 24.5% | | tiny-llama/lora.yaml | tiny-llama | 3 | 512 | 8 | 1xA10G | 8.09 min | $0.2 | 13.50 GB | 1.039 | 2.826 | 25.5% | | tiny-llama/qlora.yaml | tiny-llama | 3 | 512 | 8 | 1xA10G | 8.70 min | $0.3 | 16.24 GB | 1.056 | 2.874 | 25.3% | *OOM = Out of memory   ## Extending the context length If you require a longer sequence length than the one used in a given config file, you can either edit `max_seq_length` under the `train` section of the config file or override it on the command line when running the finetuning command, for example, `--train.max_seq_length 4096`, which takes precedence over the sequence length provided in the config file.   
## Training on GPUs without bfloat16 support If you are training on GPUs without bfloat-16 support, you need to change the `precision` option to `16-true` (16-bit floating point precision) or `16-mixed` (16/32-bit mixed precision) training: ```bash litgpt finetune lora \ --config config_hub/finetune/phi-2/lora.yaml \ --precision 16-true ``` or ```bash litgpt finetune lora \ --config config_hub/finetune/phi-2/lora.yaml \ --precision 16-mixed ``` Note that `16-true` is more compute- and memory-efficient, but it can sometimes lead to training convergence issues. In that case, use `16-mixed` instead.   ## Multi-GPU experiments Most of the provided config files use a single GPU by default (`devices: 1`). To utilize more than one GPU, pass `--devices`, for example `--devices 4`: ```bash litgpt finetune lora \ --config config_hub/finetune/phi-2/lora.yaml \ --devices 4 ``` ================================================ FILE: config_hub/finetune/falcon-7b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/tiiuae/falcon-7b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-falcon-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. 
(type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/falcon-7b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/tiiuae/falcon-7b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-falcon-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. 
(type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. 
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma-2b/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/full-gemma-2b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 4 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 16 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 100 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: 50 # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. 
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma-2b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-gemma-2b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. 
(type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma-2b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-gemma-2b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha. 
(type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. 
Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma-7b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-7b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-gemma-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. 
See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. 
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma-7b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-7b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-gemma-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. 
(type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma2-2b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2-2b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-gemma-2-2b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha.
(type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. 
Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma2-2b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2-2b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-gemma-2-2b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm.
See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma2-9b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2-9b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-gemma-2-9b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT.
(type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/gemma2-9b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/google/gemma-2-9b # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-gemma-2-9b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 16 # The LoRA alpha.
(type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.1 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 6 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. 
Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-2-7b/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf # Directory in which to save checkpoints and logs. (type: , default: out/finetune/full) out_dir: out/finetune/full-llama2-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
(type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use (type: Union[int, str], default: 1) devices: 4 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 64 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 25 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-2-7b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-llama2-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha.
(type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-2-7b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama2-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
(type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3-8b/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B # Directory in which to save checkpoints and logs. (type: , default: out/finetune/full) out_dir: out/finetune/full-llama-3-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use (type: Union[int, str], default: 1) devices: 4 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 64 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 25 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3-8b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-llama-3-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3-8b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama3-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.1-8b/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B # Directory in which to save checkpoints and logs. (type: , default: out/finetune/full) out_dir: out/finetune/full-llama-3.1-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use (type: Union[int, str], default: 1) devices: 4 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 64 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 25 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.1-8b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-llama-3.1-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.1-8b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama3.1-8b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-1B/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B # Directory in which to save checkpoints and logs. (type: , default: out/finetune/full) out_dir: out/finetune/full-llama-3.2-1B # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) # resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 64 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 25 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-1B/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-llama-3.2-1B # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-1B/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-llama3.2-1b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-3B/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B # Directory in which to save checkpoints and logs. (type: Path, default: out/finetune/full) out_dir: out/finetune/full-llama-3.2-3B # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but will not raise an error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) # resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 64 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 25 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-3B/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-llama-3.2-3B # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/llama-3.2-3B/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-llama3.2-3b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/mistral-7b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-mistral-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/mistral-7b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-mistral-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/mistral-7b-v0.2/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-mistral-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/mistral-7b-v0.2/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-mistral-7b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 2 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run (type: Optional[int], default: null) max_steps: # Limits the length of samples (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-2/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/phi-2 # Directory in which to save checkpoints and logs. (type: Path, default: out/finetune/full) out_dir: out/finetune/full-phi-2 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 2 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: 100 # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-2/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/phi-2 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-phi-2 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in each transformer block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to the output head in GPT.
(type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-2/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/phi-2 # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-phi-2 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha.
(type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in each transformer block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to the output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples.
Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-3/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct # Directory in which to save checkpoints and logs. (type: Path, default: out/finetune/full) out_dir: out/finetune/full-phi-3 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
(type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use (type: Union[int, str], default: 1) devices: 1 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 1) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 600) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-3/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-phi-3 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value.
(type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in each transformer block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to the output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights.
(type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/phi-3/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-phi-3 # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
(type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # The LoRA rank. (type: int, default: 8) lora_r: 8 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in each transformer block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to the output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/stablelm-base-alpha-3b/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b # Directory in which to save checkpoints and logs. (type: Path, default: out/finetune/full) out_dir: out/finetune/full-stablelm-base-alpha-3b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 2 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 1000 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/stablelm-base-alpha-3b/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/lora-stablelm-base-alpha-3b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in each transformer block.
(type: bool, default: False) lora_mlp: false # Whether to apply LoRA to the output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments.
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/stablelm-base-alpha-3b/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b # Directory in which to save checkpoints and logs. (type: Path, default: out/lora) out_dir: out/finetune/qlora-stablelm-base-alpha-3b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: false # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: false # Whether to apply LoRA to the weights of the MLP in each transformer block. (type: bool, default: False) lora_mlp: false # Whether to apply LoRA to the output head in GPT. (type: bool, default: False) lora_head: false # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.05 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 download_dir: data/alpaca2k # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 200 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 4 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/tiny-llama/full.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: Path, default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T # Directory in which to save checkpoints and logs. (type: Path, default: out/finetune/full) out_dir: out/finetune/full-tiny-llama-1.1b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments.
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 32 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 1000 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 25 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility.
(type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/tiny-llama/lora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/lora-tiny-llama-1.1b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank. (type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block.
(type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 8 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 3 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/finetune/tiny-llama/qlora.yaml ================================================ # The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T # Directory in which to save checkpoints and logs. (type: , default: out/lora) out_dir: out/finetune/qlora-tiny-llama-1.1b # The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-true # If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['bnb.nf4', 'bnb.nf4-dq', 'bnb.fp4', 'bnb.fp4-dq', 'bnb.int8-training']], default: null) quantize: bnb.nf4 # How many devices/GPUs to use. (type: Union[int, str], default: 1) devices: 1 # How many nodes to use. (type: int, default: 1) num_nodes: 1 # The LoRA rank.
(type: int, default: 8) lora_r: 32 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 # The LoRA dropout value. (type: float, default: 0.05) lora_dropout: 0.05 # Whether to apply LoRA to the query weights in attention. (type: bool, default: True) lora_query: true # Whether to apply LoRA to the key weights in attention. (type: bool, default: False) lora_key: true # Whether to apply LoRA to the value weights in attention. (type: bool, default: True) lora_value: true # Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) lora_projection: true # Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) lora_mlp: true # Whether to apply LoRA to output head in GPT. (type: bool, default: False) lora_head: true # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: class_path: litgpt.data.Alpaca2k init_args: mask_prompt: false val_split_fraction: 0.03847 prompt_style: alpaca ignore_index: -100 seed: 42 num_workers: 4 # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 800 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) global_batch_size: 8 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 8 # Number of iterations with learning rate warmup active (type: int, default: 100) lr_warmup_steps: 10 # Number of epochs to train on (type: Optional[int], default: 5) epochs: 3 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. 
Off by default (type: Optional[int], default: null) max_seq_length: 512 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) tie_embeddings: # (type: Optional[float], default: null) max_norm: # (type: float, default: 6e-05) min_lr: 6.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 100) interval: 100 # Number of tokens to generate (type: Optional[int], default: 100) max_new_tokens: 100 # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: true # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 1337) seed: 1337 # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0002 # (type: float, default: 0.01) weight_decay: 0.0 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 ================================================ FILE: config_hub/pretrain/debug.yaml ================================================ # The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with # ``model_config``. (type: Optional[str], default: null) model_name: pythia-14m # A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with # ``model_name``. (type: Optional[Config], default: null) model_config: # Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in # /teamspace/jobs//share.
(type: , default: out/pretrain) out_dir: out/pretrain/debug # The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-mixed # Optional path to a checkpoint directory to initialize the model from. # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null) initial_checkpoint_dir: # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. data: TinyStories # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 1000 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512) global_batch_size: 125 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 5 # Number of iterations with learning rate warmup active (type: int, default: 2000) lr_warmup_steps: 100 # Number of epochs to train on (type: Optional[int], default: null) epochs: # Total number of tokens to train on (type: Optional[int], default: 3000000000000) max_tokens: 100000000 # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: False) tie_embeddings: # (type: Optional[float], default: 1.0) max_norm: 1.0 # (type: float, default: 4e-05) min_lr: 6e-5 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 1000) interval: 1000 # Number of tokens to generate (type: Optional[int], default: null) max_new_tokens: # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: false # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 6e-4 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 # How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto) devices: auto # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data # modules require this. (type: Optional[Path], default: null) tokenizer_dir: checkpoints/EleutherAI/pythia-14m # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard) logger_name: tensorboard # The random seed to use for reproducibility. (type: int, default: 42) seed: 42 ================================================ FILE: config_hub/pretrain/microllama.yaml ================================================ # The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with # ``model_config``. (type: Optional[str], default: null) model_name: micro-llama-300M # A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with # ``model_name``.
(type: Optional[Config], default: null) model_config: # Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in # /teamspace/jobs//share. (type: , default: out/pretrain) out_dir: out/pretrain/micro-llama # The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-mixed # Optional path to a checkpoint directory to initialize the model from. # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null) initial_checkpoint_dir: # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. data: MicroLlama # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 1000 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48) # Scale this number according to the number of GPU and memory size per GPU # For example, we used 48 for 4 x 24G 4090 global_batch_size: 48 # Number of samples per data-parallel rank (type: int, default: 12) # Scale this number according to the memory size per GPU # For example, we used 12 for 24G 4090 micro_batch_size: 12 # Number of iterations with learning rate warmup active (type: int, default: 2000) lr_warmup_steps: 2000 # Number of epochs to train on (type: Optional[int], default: null) epochs: # Total number of tokens to train on (type: Optional[int], default: 3000000000000) max_tokens: 3000000000000 # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 2048 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False) tie_embeddings: # (type: Optional[float], default: 1.0) max_norm: 1.0 # (type: float, default: 4e-05) min_lr: 4.0e-05 # Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 1000) interval: 1000 # Number of tokens to generate (type: Optional[int], default: null) max_new_tokens: # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 4e-4 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 # How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto) devices: auto # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data # modules require this. (type: Optional[Path], default: null) tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard) logger_name: tensorboard # The random seed to use for reproducibility. (type: int, default: 42) seed: 42 ================================================ FILE: config_hub/pretrain/tinyllama.yaml ================================================ # The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with # ``model_config``. (type: Optional[str], default: null) model_name: tiny-llama-1.1b # A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with # ``model_name``. (type: Optional[Config], default: null) model_config: # Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in # /teamspace/jobs//share. (type: , default: out/pretrain) out_dir: out/pretrain/tiny-llama # The precision to use for pretraining.
Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-mixed # Optional path to a checkpoint directory to initialize the model from. # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null) initial_checkpoint_dir: # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. data: TinyLlama # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 1000 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512) global_batch_size: 512 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 4 # Number of iterations with learning rate warmup active (type: int, default: 2000) lr_warmup_steps: 2000 # Number of epochs to train on (type: Optional[int], default: null) epochs: # Total number of tokens to train on (type: Optional[int], default: 3000000000000) max_tokens: 3000000000000 # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 2048 # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: False) tie_embeddings: # (type: Optional[float], default: 1.0) max_norm: 1.0 # (type: float, default: 4e-05) min_lr: 4.0e-05 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 1000) interval: 1000 # Number of tokens to generate (type: Optional[int], default: null) max_new_tokens: # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: false # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 4e-4 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 # How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto) devices: auto # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data # modules require this. (type: Optional[Path], default: null) tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard) logger_name: tensorboard # The random seed to use for reproducibility. (type: int, default: 42) seed: 42 ================================================ FILE: config_hub/pretrain/tinystories.yaml ================================================ # The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with # ``model_config``. (type: Optional[str], default: null) model_name: stories15M # A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with # ``model_name``.
(type: Optional[Config], default: null) model_config: name: stories15M hf_config: {} scale_embeddings: false block_size: 256 padded_vocab_size: 32000 n_layer: 6 n_head: 6 n_query_groups: 6 n_embd: 288 head_size: 48 rotary_percentage: 1.0 parallel_residual: false bias: false norm_class_name: RMSNorm mlp_class_name: LLaMAMLP intermediate_size: 768 # Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in # /teamspace/jobs//share. (type: , default: out/pretrain) out_dir: out/pretrain/stories15M # The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) precision: bf16-mixed # Optional path to a checkpoint directory to initialize the model from. # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null) initial_checkpoint_dir: # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume # from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing # ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. # (type: Union[bool, Literal["auto"], Path], default: False) resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. data: TinyStories # Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details train: # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) save_interval: 1000 # Number of iterations between logging calls (type: int, default: 1) log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512) global_batch_size: 512 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 128 # Number of iterations with learning rate warmup active (type: int, default: 2000) lr_warmup_steps: 1000 # Number of epochs to train on (type: Optional[int], default: null) epochs: # Total number of tokens to train on (type: Optional[int], default: 3000000000000) max_tokens: 9700000000 # original did 298,000 iters # Limits the number of optimizer steps to run. (type: Optional[int], default: null) max_steps: # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 256 # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False) tie_embeddings: true # (type: Optional[float], default: 1.0) max_norm: 1.0 # (type: float, default: 4e-05) min_lr: 0.0 # Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details eval: # Number of optimizer steps between evaluation calls (type: int, default: 1000) interval: 2000 # Number of tokens to generate (type: Optional[int], default: null) max_new_tokens: # Number of iterations (type: int, default: 100) max_iters: 100 # Whether to evaluate on the validation set at the beginning of the training initial_validation: false # Whether to evaluate on the validation set at the end of the training final_validation: false # Optimizer-related arguments optimizer: class_path: torch.optim.AdamW init_args: # (type: float, default: 0.001) lr: 0.0005 # (type: float, default: 0.01) weight_decay: 0.1 # (type: tuple, default: (0.9,0.999)) betas: - 0.9 - 0.95 # How many devices/GPUs to use.
Uses all GPUs by default. (type: Union[int, str], default: auto) devices: auto # How many nodes to use. (type: int, default: 1) num_nodes: 1 # Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data # modules require this. (type: Optional[Path], default: null) tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf # The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard) logger_name: csv # The random seed to use for reproducibility. (type: int, default: 42) seed: 42 ================================================ FILE: extensions/thunder/README.md ================================================ # Lightning Thunder: a source-to-source compiler for PyTorch [Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) makes PyTorch programs faster, both on single accelerators and in distributed settings. Thunder aims to be usable, understandable, and extensible, and can achieve significant speedups over standard PyTorch eager code through the compounding effects of optimizations and the use of best-in-class executors. This extension directory shows how Thunder can be used with LitGPT. > [!WARNING] > This document is an early-access development version that is currently only for internal use. We recommend that users check out the [Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) project directly, which provides more up-to-date usage information.   ## Thunder 👉👈 LitGPT: a short showcase To try Lightning Thunder with your model, simply wrap it with `thunder.jit()`.
```python from litgpt import GPT import thunder import torch # Use only two layers to keep the traces shorter for the demonstration model = GPT.from_name("Llama-2-7b-hf", n_layer=2).cuda() model = thunder.jit(model) x = torch.randint(model.max_seq_length, (2, 5), device="cuda") y = model(x) # forward, this may take a bit ``` This will require some compilation time on the first forward call. ### Traces The JIT will acquire a Python program (what we call a "trace") from the Python program that was given (`GPT`, a `torch.nn.Module` in this example). This process targets PyTorch operators (like `Tensor.view()`, `+`, `torch.nn.functional.scaled_dot_product_attention()`) and, optionally, custom operators (more about that later). We can visualize the Thunder trace generated under the hood: ```python forward_trace = thunder.last_traces(model)[-1].python() print(forward_trace) ``` ```python @torch.no_grad() @no_autocast() def augmented_forward_fn(*args): # args: "Collection" t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \ t18, t19, = args del args t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False) # t24: "cuda:0 f32[2, 5, 4096]" t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1]) # t20: "cuda:0 f32[5, 128]" t21 = torch_slice_prim_impl(t2, [0, 0], [5, 128], [1, 1]) # t21: "cuda:0 f32[5, 128]" t200 = torch.unsqueeze(t11, 0) # t200: "cuda:0 f32[1, 4096]" t201 = torch.unsqueeze(t200, 1) # t201: "cuda:0 f32[1, 1, 4096]" del t200 t33 = Tensor.expand(t201, (2, 5, 4096)) # t33: "cuda:0 f32[2, 5, 4096]" del t201 t229 = torch.unsqueeze(t13, 0) # t229: "cuda:0 f32[1, 4096]" t230 = torch.unsqueeze(t229, 1) # t230: "cuda:0 f32[1, 1, 4096]" del t229 t84 = Tensor.expand(t230, (2, 5, 4096)) # t84: "cuda:0 f32[2, 5, 4096]" del t230 t232 = torch.unsqueeze(t12, 0) # t232: "cuda:0 f32[1, 4096]" t233 = torch.unsqueeze(t232, 1) # t233: "cuda:0 f32[1, 1, 4096]" del t232 t104 = Tensor.expand(t233, (2, 5, 4096)) # t104: "cuda:0
f32[2, 5, 4096]" del t233 t253 = torch.unsqueeze(t14, 0) # t253: "cuda:0 f32[1, 4096]" t254 = torch.unsqueeze(t253, 1) # t254: "cuda:0 f32[1, 1, 4096]" del t253 t155 = Tensor.expand(t254, (2, 5, 4096)) # t155: "cuda:0 f32[2, 5, 4096]" del t254 t256 = torch.unsqueeze(t10, 0) # t256: "cuda:0 f32[1, 4096]" t257 = torch.unsqueeze(t256, 1) # t257: "cuda:0 f32[1, 1, 4096]" del t256 t175 = Tensor.expand(t257, (2, 5, 4096)) # t175: "cuda:0 f32[2, 5, 4096]" del t257 t221 = torch.unsqueeze(t20, 0) # t221: "cuda:0 f32[1, 5, 128]" del t20 t222 = torch.unsqueeze(t221, 1) # t222: "cuda:0 f32[1, 1, 5, 128]" del t221 t49 = Tensor.expand(t222, (2, 32, 5, 128)) # t49: "cuda:0 f32[2, 32, 5, 128]" del t222 t224 = torch.unsqueeze(t21, 0) # t224: "cuda:0 f32[1, 5, 128]" del t21 t225 = torch.unsqueeze(t224, 1) # t225: "cuda:0 f32[1, 1, 5, 128]" del t224 t51 = Tensor.expand(t225, (2, 32, 5, 128)) # t51: "cuda:0 f32[2, 32, 5, 128]" del t225 [t30, t34] = nvFusion0(t24, t33) t35 = torch.nn.functional.linear(t34, t3, None) # t35: "cuda:0 f32[2, 5, 12288]" t36 = torch.reshape(t35, (2, 5, 32, 3, 128)) # t36: "cuda:0 f32[2, 5, 32, 3, 128]" del t35 t37 = torch.permute(t36, (0, 2, 3, 1, 4)) # t37: "cuda:0 f32[2, 32, 3, 5, 128]" del t36 (t38, t39, t40) = torch.split(t37, (1, 1, 1), 2) del t37 t41 = torch.reshape(t38, (2, 32, 5, 128)) # t41: "cuda:0 f32[2, 32, 5, 128]" del t38 t42 = torch.reshape(t39, (2, 32, 5, 128)) # t42: "cuda:0 f32[2, 32, 5, 128]" del t39 t43 = torch.reshape(t40, (2, 32, 5, 128)) # t43: "cuda:0 f32[2, 32, 5, 128]" del t40 t44 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t44: "cuda:0 f32[2, 32, 5, 128]" t54 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t54: "cuda:0 f32[2, 32, 5, 128]" t64 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t64: "cuda:0 f32[2, 32, 5, 0]" del t41 t66 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t66: "cuda:0 f32[2, 32, 5, 0]" del t42 t46 
= torch_slice_prim_impl(t44, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t46: "cuda:0 f32[2, 32, 5, 64]" t45 = torch_slice_prim_impl(t44, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t45: "cuda:0 f32[2, 32, 5, 64]" t55 = torch_slice_prim_impl(t54, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t55: "cuda:0 f32[2, 32, 5, 64]" t56 = torch_slice_prim_impl(t54, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t56: "cuda:0 f32[2, 32, 5, 64]" [t47, t57] = nvFusion1(t46, t56) del t46, t56 t48 = torch.cat((t47, t45), -1) # t48: "cuda:0 f32[2, 32, 5, 128]" del t47, t45 t58 = torch.cat((t57, t55), -1) # t58: "cuda:0 f32[2, 32, 5, 128]" del t57, t55 [t53, t63] = nvFusion2(t44, t48, t49, t51, t54, t58) del t44, t48, t54, t58 t65 = torch.cat((t53, t64), -1) # t65: "cuda:0 f32[2, 32, 5, 128]" del t53, t64 t67 = torch.cat((t63, t66), -1) # t67: "cuda:0 f32[2, 32, 5, 128]" del t63, t66 (t68, t69, t70, t71) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t65, t67, t43, None, 0.0, True, 0.08838834764831843) t72 = torch.permute(t68, (0, 2, 1, 3)) # t72: "cuda:0 f32[2, 5, 32, 128]" t73 = torch.reshape(t72, (2, 5, 4096)) # t73: "cuda:0 f32[2, 5, 4096]" del t72 t74 = torch.nn.functional.linear(t73, t15, None) # t74: "cuda:0 f32[2, 5, 4096]" [t75, t81, t85] = nvFusion3(t24, t74, t84) del t74 t86 = torch.nn.functional.linear(t85, t5, None) # t86: "cuda:0 f32[2, 5, 11008]" t87 = torch.nn.functional.linear(t85, t7, None) # t87: "cuda:0 f32[2, 5, 11008]" [t93] = nvFusion4(t86, t87) t94 = torch.nn.functional.linear(t93, t16, None) # t94: "cuda:0 f32[2, 5, 4096]" [t101, t105, t95] = nvFusion5(t104, t75, t94) del t94 t106 = torch.nn.functional.linear(t105, t4, None) # t106: "cuda:0 f32[2, 5, 12288]" t107 = torch.reshape(t106, (2, 5, 32, 3, 128)) # t107: "cuda:0 f32[2, 5, 32, 3, 128]" del t106 t108 = torch.permute(t107, (0, 2, 3, 1, 4)) # t108: "cuda:0 f32[2, 32, 3, 5, 128]" del t107 (t109, t110, t111) = torch.split(t108, (1, 1, 1), 2) del t108 t112 = torch.reshape(t109, (2, 32, 
5, 128)) # t112: "cuda:0 f32[2, 32, 5, 128]" del t109 t113 = torch.reshape(t110, (2, 32, 5, 128)) # t113: "cuda:0 f32[2, 32, 5, 128]" del t110 t114 = torch.reshape(t111, (2, 32, 5, 128)) # t114: "cuda:0 f32[2, 32, 5, 128]" del t111 t135 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t135: "cuda:0 f32[2, 32, 5, 0]" t137 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t137: "cuda:0 f32[2, 32, 5, 0]" t115 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t115: "cuda:0 f32[2, 32, 5, 128]" del t112 t125 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t125: "cuda:0 f32[2, 32, 5, 128]" del t113 t116 = torch_slice_prim_impl(t115, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t116: "cuda:0 f32[2, 32, 5, 64]" t117 = torch_slice_prim_impl(t115, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t117: "cuda:0 f32[2, 32, 5, 64]" t127 = torch_slice_prim_impl(t125, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t127: "cuda:0 f32[2, 32, 5, 64]" t126 = torch_slice_prim_impl(t125, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t126: "cuda:0 f32[2, 32, 5, 64]" [t118, t128] = nvFusion6(t117, t127) del t117, t127 t129 = torch.cat((t128, t126), -1) # t129: "cuda:0 f32[2, 32, 5, 128]" del t128, t126 t119 = torch.cat((t118, t116), -1) # t119: "cuda:0 f32[2, 32, 5, 128]" del t118, t116 [t124, t134] = nvFusion7(t115, t119, t125, t129, t49, t51) del t115, t119, t125, t129 t136 = torch.cat((t124, t135), -1) # t136: "cuda:0 f32[2, 32, 5, 128]" del t124, t135 t138 = torch.cat((t134, t137), -1) # t138: "cuda:0 f32[2, 32, 5, 128]" del t134, t137 (t139, t140, t141, t142) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t136, t138, t114, None, 0.0, True, 0.08838834764831843) t143 = torch.permute(t139, (0, 2, 1, 3)) # t143: "cuda:0 f32[2, 5, 32, 128]" t144 = torch.reshape(t143, (2, 5, 4096)) # t144: "cuda:0 f32[2, 5, 4096]" del t143 t145 = torch.nn.functional.linear(t144, t17, 
None) # t145: "cuda:0 f32[2, 5, 4096]" [t146, t152, t156] = nvFusion8(t145, t155, t95) del t145 t158 = torch.nn.functional.linear(t156, t8, None) # t158: "cuda:0 f32[2, 5, 11008]" t157 = torch.nn.functional.linear(t156, t6, None) # t157: "cuda:0 f32[2, 5, 11008]" [t164] = nvFusion9(t157, t158) t165 = torch.nn.functional.linear(t164, t18, None) # t165: "cuda:0 f32[2, 5, 4096]" [t166, t172, t176] = nvFusion10(t146, t165, t175) del t165 t177 = torch.nn.functional.linear(t176, t9, None) # t177: "cuda:0 f32[2, 5, 32000]" return {'output': t177, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19], 'flat_output': (t177,)}, ((t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95), (False, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2)) ``` This is a straight-line version of `GPT.forward` that has been optimized. Since it's running on CUDA, the [NvFuser](https://github.com/NVIDIA/Fuser) executor has created regions (look for "nvFusion") that fuse multiple operators together. Operator fusion is very desirable with modern hardware and helps in overhead-bound or device-bound settings by: - Launching fewer kernels, thus reducing the kernel launch overhead. - Reducing the number of memory accesses by reusing data within a fused operation. - Minimizing host-device communications. Thunder also uses a multi-level intermediate representation. If we let it print all levels: ```python forward_trace = thunder.last_traces(model)[-1] print(forward_trace) ``` we can see, as comments, the primitives that compose each fusion region.
For instance, this is the region associated to [the `RMSNorm` implementation](https://github.com/Lightning-AI/litgpt/blob/9b6475dabf90c7acee506a026bd9fa86251835bf/litgpt/model.py#L409-L420) ```python [t146, t152, t156] = nvFusion8(t145, t155, t95) # t146 = prims.add(t145, t95) # t146: "cuda:0 f32[2, 5, 4096]" # t147 = prims.mul(t146, t146) # t147: "cuda:0 f32[2, 5, 4096]" # t148 = prims.sum(t147, (2,)) # t148: "cuda:0 f32[2, 5]" # t149 = prims.broadcast_in_dim(t148, [2, 5, 1], [0, 1]) # t149: "cuda:0 f32[2, 5, 1]" # t150 = prims.div(t149, 4096.0) # t150: "cuda:0 f32[2, 5, 1]" # t151 = prims.add(t150, 1e-05) # t151: "cuda:0 f32[2, 5, 1]" # t152 = prims.rsqrt(t151) # t152: "cuda:0 f32[2, 5, 1]" # t153 = prims.broadcast_in_dim(t152, (2, 5, 4096), (0, 1, 2)) # t153: "cuda:0 f32[2, 5, 4096]" # t154 = prims.mul(t146, t153) # t154: "cuda:0 f32[2, 5, 4096]" # t156 = prims.mul(t154, t155) # t156: "cuda:0 f32[2, 5, 4096]" ``` Similarly, we can visualize the backward trace: ```python backward_trace = thunder.last_backward_traces(model)[-1].python() print(backward_trace) ``` ```python @torch.no_grad() @no_autocast() def backward_fn(saved_for_backward, cotangents): # saved_for_backward: "Collection" # cotangents: "Collection" C0, C1, = saved_for_backward clear_collection(saved_for_backward) del saved_for_backward t178, = cotangents clear_collection(cotangents) del cotangents t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \ t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \ t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \ t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0 clear_collection(C0) del C0 b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \ = C1 clear_collection(C1) del C1 t639 = torch.reshape(t178, (-1, 32000)) # t639: "cuda:0 f32[10, 32000]" del t178 t643 = torch.permute(t639, (1, 0)) # t643: "cuda:0 f32[32000, 10]" t644 = 
torch.reshape(t176, (-1, 4096)) # t644: "cuda:0 f32[10, 4096]" del t176 t669 = torch.reshape(t164, (-1, 11008)) # t669: "cuda:0 f32[10, 11008]" del t164 t686 = torch.reshape(t156, (-1, 4096)) # t686: "cuda:0 f32[10, 4096]" del t156 t720 = torch.reshape(t144, (-1, 4096)) # t720: "cuda:0 f32[10, 4096]" del t144 t776 = torch.reshape(t105, (-1, 4096)) # t776: "cuda:0 f32[10, 4096]" del t105 t802 = torch.reshape(t93, (-1, 11008)) # t802: "cuda:0 f32[10, 11008]" del t93 t819 = torch.reshape(t85, (-1, 4096)) # t819: "cuda:0 f32[10, 4096]" del t85 t853 = torch.reshape(t73, (-1, 4096)) # t853: "cuda:0 f32[10, 4096]" del t73 t911 = torch.reshape(t34, (-1, 4096)) # t911: "cuda:0 f32[10, 4096]" del t34 t640 = torch.matmul(t639, t9) # t640: "cuda:0 f32[10, 4096]" del t639, t9 t645 = torch.matmul(t643, t644) # t645: "cuda:0 f32[32000, 4096]" del t643, t644 t641 = torch.reshape(t640, (2, 5, 4096)) # t641: "cuda:0 f32[2, 5, 4096]" del t640 [t648, t663] = nvFusion0(f106, t166, t172, t175, t641) del f106, t166, t172, t175, t641 t664 = torch.reshape(t663, (-1, 4096)) # t664: "cuda:0 f32[10, 4096]" t668 = torch.permute(t664, (1, 0)) # t668: "cuda:0 f32[4096, 10]" t665 = torch.matmul(t664, t18) # t665: "cuda:0 f32[10, 11008]" del t664, t18 t670 = torch.matmul(t668, t669) # t670: "cuda:0 f32[4096, 11008]" del t668, t669 t666 = torch.reshape(t665, (2, 5, 11008)) # t666: "cuda:0 f32[2, 5, 11008]" del t665 [t672, t680] = nvFusion1(t157, t158, t666) del t157, t158, t666 t681 = torch.reshape(t672, (-1, 11008)) # t681: "cuda:0 f32[10, 11008]" del t672 t685 = torch.permute(t681, (1, 0)) # t685: "cuda:0 f32[11008, 10]" t688 = torch.reshape(t680, (-1, 11008)) # t688: "cuda:0 f32[10, 11008]" del t680 t692 = torch.permute(t688, (1, 0)) # t692: "cuda:0 f32[11008, 10]" t689 = torch.matmul(t688, t6) # t689: "cuda:0 f32[10, 4096]" del t688, t6 t682 = torch.matmul(t681, t8) # t682: "cuda:0 f32[10, 4096]" del t681, t8 t694 = torch.matmul(t692, t686) # t694: "cuda:0 f32[11008, 4096]" del t692 t687 = 
torch.matmul(t685, t686) # t687: "cuda:0 f32[11008, 4096]" del t685, t686 t683 = torch.reshape(t682, (2, 5, 4096)) # t683: "cuda:0 f32[2, 5, 4096]" del t682 t690 = torch.reshape(t689, (2, 5, 4096)) # t690: "cuda:0 f32[2, 5, 4096]" del t689 [t698, t714] = nvFusion2(f101, t146, t152, t155, t663, t683, t690) del f101, t146, t152, t155, t663, t683, t690 t715 = torch.reshape(t714, (-1, 4096)) # t715: "cuda:0 f32[10, 4096]" t719 = torch.permute(t715, (1, 0)) # t719: "cuda:0 f32[4096, 10]" t716 = torch.matmul(t715, t17) # t716: "cuda:0 f32[10, 4096]" del t715, t17 t721 = torch.matmul(t719, t720) # t721: "cuda:0 f32[4096, 4096]" del t719, t720 t717 = torch.reshape(t716, (2, 5, 4096)) # t717: "cuda:0 f32[2, 5, 4096]" del t716 t722 = torch.reshape(t717, (2, 5, 32, 128)) # t722: "cuda:0 f32[2, 5, 32, 128]" del t717 t723 = torch.permute(t722, (0, 2, 1, 3)) # t723: "cuda:0 f32[2, 32, 5, 128]" del t722 (t724, t725, t726, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t723, t136, t138, t114, None, t139, t140, t141, t142, f90, b91, scale=f92) del t723, t136, t138, t114, t139, t140, t141, t142, f90, b91, f92 t765 = torch.reshape(t726, (2, 32, 1, 5, 128)) # t765: "cuda:0 f32[2, 32, 1, 5, 128]" del t726 t727 = torch_slice_prim_impl(t725, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t727: "cuda:0 f32[2, 32, 5, 128]" del t725 t730 = torch_slice_prim_impl(t724, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t730: "cuda:0 f32[2, 32, 5, 128]" del t724 [t747, t764] = nvFusion3(t49, t51, t727, t730) del t727, t730 t766 = torch.reshape(t747, (2, 32, 1, 5, 128)) # t766: "cuda:0 f32[2, 32, 1, 5, 128]" del t747 t767 = torch.reshape(t764, (2, 32, 1, 5, 128)) # t767: "cuda:0 f32[2, 32, 1, 5, 128]" del t764 t768 = torch.cat((t767, t766, t765), i73) # t768: "cuda:0 f32[2, 32, 3, 5, 128]" del t767, t766, t765, i73 t769 = torch.permute(t768, (0, 3, 1, 2, 4)) # t769: "cuda:0 f32[2, 5, 32, 3, 128]" del t768 t770 = torch.reshape(t769, (2, 5, 12288)) # t770: "cuda:0 f32[2, 5, 12288]" del 
t769 t771 = torch.reshape(t770, (-1, 12288)) # t771: "cuda:0 f32[10, 12288]" del t770 t775 = torch.permute(t771, (1, 0)) # t775: "cuda:0 f32[12288, 10]" t777 = torch.matmul(t775, t776) # t777: "cuda:0 f32[12288, 4096]" del t775, t776 t772 = torch.matmul(t771, t4) # t772: "cuda:0 f32[10, 4096]" del t771, t4 t773 = torch.reshape(t772, (2, 5, 4096)) # t773: "cuda:0 f32[2, 5, 4096]" del t772 [t780, t796] = nvFusion4(f56, t101, t104, t714, t773, t95) del f56, t101, t104, t714, t773, t95 t797 = torch.reshape(t796, (-1, 4096)) # t797: "cuda:0 f32[10, 4096]" t801 = torch.permute(t797, (1, 0)) # t801: "cuda:0 f32[4096, 10]" t798 = torch.matmul(t797, t16) # t798: "cuda:0 f32[10, 11008]" del t797, t16 t803 = torch.matmul(t801, t802) # t803: "cuda:0 f32[4096, 11008]" del t801, t802 t799 = torch.reshape(t798, (2, 5, 11008)) # t799: "cuda:0 f32[2, 5, 11008]" del t798 [t805, t813] = nvFusion5(t799, t86, t87) del t799, t86, t87 t814 = torch.reshape(t805, (-1, 11008)) # t814: "cuda:0 f32[10, 11008]" del t805 t818 = torch.permute(t814, (1, 0)) # t818: "cuda:0 f32[11008, 10]" t821 = torch.reshape(t813, (-1, 11008)) # t821: "cuda:0 f32[10, 11008]" del t813 t825 = torch.permute(t821, (1, 0)) # t825: "cuda:0 f32[11008, 10]" t822 = torch.matmul(t821, t5) # t822: "cuda:0 f32[10, 4096]" del t821, t5 t815 = torch.matmul(t814, t7) # t815: "cuda:0 f32[10, 4096]" del t814, t7 t827 = torch.matmul(t825, t819) # t827: "cuda:0 f32[11008, 4096]" del t825 t820 = torch.matmul(t818, t819) # t820: "cuda:0 f32[11008, 4096]" del t818, t819 t816 = torch.reshape(t815, (2, 5, 4096)) # t816: "cuda:0 f32[2, 5, 4096]" del t815 t823 = torch.reshape(t822, (2, 5, 4096)) # t823: "cuda:0 f32[2, 5, 4096]" del t822 [t831, t847] = nvFusion6(f51, t75, t796, t81, t816, t823, t84) del f51, t75, t796, t81, t816, t823, t84 t848 = torch.reshape(t847, (-1, 4096)) # t848: "cuda:0 f32[10, 4096]" t852 = torch.permute(t848, (1, 0)) # t852: "cuda:0 f32[4096, 10]" t849 = torch.matmul(t848, t15) # t849: "cuda:0 f32[10, 4096]" del 
t848, t15 t854 = torch.matmul(t852, t853) # t854: "cuda:0 f32[4096, 4096]" del t852, t853 t850 = torch.reshape(t849, (2, 5, 4096)) # t850: "cuda:0 f32[2, 5, 4096]" del t849 t855 = torch.reshape(t850, (2, 5, 32, 128)) # t855: "cuda:0 f32[2, 5, 32, 128]" del t850 t856 = torch.permute(t855, (0, 2, 1, 3)) # t856: "cuda:0 f32[2, 32, 5, 128]" del t855 (t857, t858, t859, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t856, t65, t67, t43, None, t68, t69, t70, t71, f40, b41, scale=f42) del t856, t65, t67, t43, t68, t69, t70, t71, f40, b41, f42 t900 = torch.reshape(t859, (2, 32, 1, 5, 128)) # t900: "cuda:0 f32[2, 32, 1, 5, 128]" del t859 t863 = torch_slice_prim_impl(t857, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t863: "cuda:0 f32[2, 32, 5, 128]" del t857 t860 = torch_slice_prim_impl(t858, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t860: "cuda:0 f32[2, 32, 5, 128]" del t858 [t882, t899] = nvFusion7(t49, t51, t860, t863) del t49, t51, t860, t863 t902 = torch.reshape(t899, (2, 32, 1, 5, 128)) # t902: "cuda:0 f32[2, 32, 1, 5, 128]" del t899 t901 = torch.reshape(t882, (2, 32, 1, 5, 128)) # t901: "cuda:0 f32[2, 32, 1, 5, 128]" del t882 t903 = torch.cat((t902, t901, t900), i23) # t903: "cuda:0 f32[2, 32, 3, 5, 128]" del t902, t901, t900, i23 t904 = torch.permute(t903, (0, 3, 1, 2, 4)) # t904: "cuda:0 f32[2, 5, 32, 3, 128]" del t903 t905 = torch.reshape(t904, (2, 5, 12288)) # t905: "cuda:0 f32[2, 5, 12288]" del t904 t906 = torch.reshape(t905, (-1, 12288)) # t906: "cuda:0 f32[10, 12288]" del t905 t910 = torch.permute(t906, (1, 0)) # t910: "cuda:0 f32[12288, 10]" t907 = torch.matmul(t906, t3) # t907: "cuda:0 f32[10, 4096]" del t906, t3 t912 = torch.matmul(t910, t911) # t912: "cuda:0 f32[12288, 4096]" del t910, t911 t908 = torch.reshape(t907, (2, 5, 4096)) # t908: "cuda:0 f32[2, 5, 4096]" del t907 [t915, t931] = nvFusion8(f6, t24, t30, t33, t847, t908) del f6, t24, t30, t33, t847, t908 t932 = torch.torch.ops.aten.embedding_backward(t931, t0, i0, -1, b1, b2) # 
t932: "cuda:0 f32[32000, 4096]" del t931, t0, i0, b1, b2 return (None, None, None, t912, t777, t827, t694, t820, t687, t645, t648, t915, t780, t831, t698, t854, t803, t721, t670, t932) ``` These traces are long and require some familiarity with the model implementation to follow, but they allow you to: - Inspect exactly which operations are run, including their decompositions. - Inspect the sizes of tensors, their device, data type and conversions. - Apply transformations to the traces, since the computations are completely decoupled from the data. - Inspect the backward operations generated for each forward operation to understand what autograd is doing. ### Transforms Transforms are one of the core features of Thunder. For example, they enable easy data-parallel distribution, both replicated data parallelism (DDP) and fully-sharded data parallelism (FSDP). We provide ready-to-use Fabric strategies that integrate Thunder DDP and FSDP. Under the hood, the code is quite straightforward: ```python model = thunder.distributed.ddp(model) # or # model = thunder.distributed.fsdp(model) model = thunder.jit(model) ``` After applying the DDP transformation, the backward trace will include the expected all-reduce collectives: ```python p1022 = torch_all_reduce_prim_impl(t1021, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True, False) # p1022: "FUTURE cuda:0 f32[16797696]" ...
t1059 = torch_wait_prim_impl(p1025) # t1059: "cuda:0 f32[131072000]" ``` With `L.Fabric`, this is how to use them: ```python from extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrategy # fully-sharded data parallel strategy = ThunderFSDPStrategy( sharding_strategy="ZERO3", bucketing_strategy="BLOCK", executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"), state_dict_type="full", ) # replicated data parallel strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile_cat", "nvfuser", "torch")) fabric = L.Fabric(devices=devices, strategy=strategy) fabric.launch() model = fabric.setup(model) # JIT is called here ``` In the case of FSDP, the traces will instead include all-gathers in forward and reduce-scatters in backward. This means that Thunder automatically introduces the necessary collective operations to support data parallelism. ### Executors Thunder allows you to define a priority list of executors that can map operators: ```python import thunder model = thunder.jit( model, executors=["sdpa", "torchcompile_cat", "nvfuser", "torch"] ) ``` Notice how `torch.compile` is a valid executor (`torchcompile_cat` above). This executor registers a few operators with improved performance so that you can utilize the fastest set of operator implementations possible. ### Custom executors Lightning Thunder provides extension points to integrate fast kernels for operators in your model without having to modify your implementation. For instance, the [Unsloth project](https://github.com/unslothai/unsloth/) provides several Triton kernels that can be used with LitGPT: - Cross entropy loss - SwiGLU (part of `LLaMAMLP`) - RoPE The [`unsloth` directory](unsloth) contains a [custom executor](unsloth/executor.py) that registers these operators for LitGPT. We can enable this executor by adding it to the list of available executors. The order matters because we want to run its custom operators before `NvFuser` creates its fusion regions.
```python import thunder model = thunder.jit( model, executors=["sdpa", "unsloth", "torchcompile_cat", "nvfuser", "torch"] ) ``` Doing this, the model trace now includes the Unsloth kernel calls: ```python def augmented_forward_fn(*args): ... (t121, _, _, _, _, _) = unsloth_apply_rope(t120, t21, t22) ... (t189, t190) = unsloth_cross_entropy(t187, t188) ... def backward_fn(saved_for_backward, cotangents): ... t652 = unsloth_cross_entropy_backward(t651, t187, t188, t190) # t652: "cuda:0 f32[6, 320]" ... t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4) # t763: "cuda:0 f32[2, 4, 3, 16]" ``` We provide a specific [pre-training script copy](pretrain.py) that uses this executor. Given the Unsloth results below, these hand-written kernels do not seem to be worth it, showcasing the power of automated fusion compilers like [NvFuser](https://github.com/NVIDIA/Fuser). ## Examples and benchmarks > [!WARNING] > Lightning Thunder is alpha and not ready for production runs. Feel free to try it out, expect a few bumps along the way. > We expect speed and memory usage to improve as we continue to develop it. We provide a version of the main pre-training script [that integrates Thunder](pretrain.py) that uses TinyLlama, a 1.1B parameter LLM. 
| Setting | Compiler | Executors | Devices | ms/iter @ step 10 | Memory (GB) | |----------------------|----------|----------------------------------------|---------|-------------------|---------------| | Fully-sharded ZeRO 3 | Eager | - | 8 | 456.57 | 22.13 | | Fully-sharded ZeRO 3 | torch | - | 8 | Not supported | Not supported | | Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile | 8 | Not supported | Not supported | | Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 333.56 | 21.40 | | | | | | | | | Replicated | Eager | - | 8 | 569.46 | 32.04 | | Replicated | torch | - | 8 | Not supported | Not supported | | Replicated | Thunder | sdpa, torchcompile | 8 | 426.44 | 22.19 | | Replicated | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 356.01 | 27.42 | | | | | | | | | - | Eager | - | 1 | 447.65 | 29.84 | | - | torch | - | 1 | Not supported | Not supported | | - | Thunder | sdpa, torchcompile | 1 | 373.37 | 22.19 | | - | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 322.25 | 27.42 | | | | | | | | | Unsloth | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 331.92 | 25.19 |
Reproduction details Config: ```yaml out_dir: out/pretrain-thunder data: TinyStories tokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0 logger_name: csv ``` Commands: ```bash litgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32 python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32 python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32 python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1 python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1 python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1 python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1 ``` `--compiler torch` (`torch.compile` without `thunder`) is not included because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601 The CUDA devices are all NVIDIA A100-SXM4-40GB.
```text Python version: 3.10.12 [GCC 11.4.0] (64-bit runtime) Is debug build: False CUDA used to build PyTorch: 12.1 CUDA runtime version: 12.3.107 Nvidia driver version: 545.23.08 pytorch-triton==3.0.0+45fff310c8 torch==2.4.0.dev20240427+cu121 lightning==2.3.0.dev20240328 lightning-thunder==0.2.0.dev20240505 nvfuser_cu121==0.2.3.dev20240428 ```
================================================
FILE: extensions/thunder/__init__.py
================================================
import sys
from pathlib import Path

# support running without installing as a package, adding extensions to the Python path
# (`Path(__file__).parent.parent` is the `extensions/` directory itself)
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


================================================
FILE: extensions/thunder/pretrain.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Pretraining script whose forward-and-loss step can be compiled with Lightning Thunder or ``torch.compile``.

``setup`` selects the compiler, the Thunder executors and the distributed strategy, then hands off to ``main``
(model/optimizer/data construction) and ``fit`` (the training loop).
"""

import math
import os
import pprint
import sys
import time
from dataclasses import asdict
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops
from torch.utils.data import DataLoader
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal

from litgpt import Tokenizer
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import DataModule, TinyLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention
from litgpt.parser_config import save_hyperparameters
from litgpt.types import LoggerChoice
from litgpt.utils import (
    CLI,
    CycleIterator,
    capture_hparams,
    choose_logger,
    chunked_cross_entropy,
    copy_config_files,
    find_resume_path,
    instantiate_torch_optimizer,
    num_parameters,
    parse_devices,
    reset_parameters,
    save_config,
)

# support running without installing as a package
# (this appends `extensions/thunder/`, the directory containing this script)
wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))


def forward_and_loss(model: nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``input_ids`` and return the cross-entropy loss against ``targets``.

    This is the unit that gets compiled: ``setup`` may rebind this module-level name to a Thunder-jitted or
    ``torch.compile``-d version, and ``fit`` calls it through the module-level name.
    """
    logits = model(input_ids)
    # disable chunk_size to enable the unsloth cross entropy kernel
    loss = chunked_cross_entropy(logits, targets, chunk_size=0)
    return loss


def setup(
    model_name: Optional[str] = None,
    model_config: Optional[Config] = None,
    out_dir: Path = Path("out/pretrain"),
    initial_checkpoint_dir: Optional[Path] = None,
    resume: Union[bool, Literal["auto"], Path] = False,
    data: Optional[DataModule] = None,
    train: TrainArgs = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=512,
        micro_batch_size=4,
        max_tokens=int(3e12),  # 3 trillion
        max_norm=1.0,
        min_lr=4e-5,
        lr_warmup_steps=2000,
        tie_embeddings=False,
    ),
    eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
    log: LogArgs = LogArgs(),
    optimizer: Union[str, Dict] = "AdamW",
    devices: Union[int, str] = "auto",
    num_nodes: int = 1,
    tokenizer_dir: Optional[Path] = None,
    logger_name: LoggerChoice = "tensorboard",
    seed: int = 42,
    compiler: Optional[Literal["thunder", "torch"]] = "thunder",
    executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
    strategy: Literal["auto", "ddp", "fsdp"] = "fsdp",
):
    """Pretrain a model.

    Arguments:
        model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
            ``model_config``. If neither is given, ``"tiny-llama-1.1b"`` is used.
        model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
            ``model_name``.
        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
            /teamspace/jobs/<job-name>/share.
        initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
            Useful for continued pretraining. Mutually exclusive with ``resume``.
        resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
            from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
            ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
        data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
        log: Logger-related arguments. See ``litgpt.args.LogArgs`` for details. Passed to the logger as ``log_args``.
        optimizer: An optimizer name (such as "AdamW") or config.
        devices: How many devices/GPUs to use. Uses all GPUs by default.
        num_nodes: How many nodes the code is being run on.
        tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
            modules require this.
        logger_name: The name of the logger to send metrics to.
        seed: The random seed to use for reproducibility.
        compiler: If desired, the compiler/JIT to use. ``"thunder"`` wraps ``forward_and_loss`` with Thunder,
            ``"torch"`` with ``torch.compile``, and ``None`` runs it eagerly.
        executors: If using Thunder, the executors to enable, in priority order.
        strategy: If desired, the strategy to use. Only applied when the total number of devices
            (``devices * num_nodes``) is greater than one; single-process runs always use ``"auto"``.
    """
    hparams = capture_hparams()
    data = TinyLlama() if data is None else data
    if model_config is not None and model_name is not None:
        raise ValueError("Only one of `model_name` or `model_config` can be set.")
    elif model_config is None and model_name is None:
        model_name = "tiny-llama-1.1b"
    config = Config.from_name(model_name) if model_config is None else model_config

    devices = parse_devices(devices)
    out_dir = init_out_dir(out_dir)
    # in case the dataset requires the Tokenizer
    tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None

    logger = choose_logger(
        logger_name,
        out_dir,
        name=f"pretrain-{config.name}",
        resume=bool(resume),
        log_interval=train.log_interval,
        log_args=asdict(log),
    )

    # Multi-process runs: with Thunder, "fsdp"/"ddp" map to the Thunder-aware strategies; otherwise "fsdp" maps to
    # Lightning's FSDPStrategy. Any other requested strategy string is passed to Fabric unchanged.
    if devices * num_nodes > 1:
        if compiler == "thunder":
            if strategy == "fsdp":
                from extensions.thunder.strategies import ThunderFSDPStrategy

                # NOTE(review): presumably `jit=False` because `forward_and_loss` is jitted separately below — confirm
                strategy = ThunderFSDPStrategy(
                    sharding_strategy="ZERO3",
                    bucketing_strategy="BLOCK",
                    state_dict_type="full",
                    jit=False,
                )
            elif strategy == "ddp":
                from extensions.thunder.strategies import ThunderDDPStrategy

                strategy = ThunderDDPStrategy(jit=False)
        else:
            if strategy == "fsdp":
                strategy = FSDPStrategy(
                    auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="FULL_SHARD"
                )
    else:
        # single-process run: the requested strategy is ignored
        strategy = "auto"
    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision="bf16-true", loggers=[logger])
    fabric.launch()

    if compiler is not None:
        # Rebind the module-level name so that `fit`, which calls `forward_and_loss` through the module-level name,
        # picks up the compiled version.
        global forward_and_loss
        forward_and_loss = (
            jit(forward_and_loss, executors) if compiler == "thunder" else torch.compile(forward_and_loss)
        )

    fabric.print(pprint.pformat(hparams))
    if logger_name in ("tensorboard", "wandb", "mlflow"):
        fabric.logger.log_hyperparams(hparams)

    main(
        fabric=fabric,
        devices=devices,
        num_nodes=num_nodes,
        seed=seed,
        initial_checkpoint_dir=initial_checkpoint_dir,
        resume=resume,
        config=config,
        data=data,
        out_dir=out_dir,
        tokenizer_dir=tokenizer_dir,
        tokenizer=tokenizer,
        train=train,
        eval=eval,
        optimizer=optimizer,
        compiler=compiler,
    )


def main(
    fabric: L.Fabric,
    devices: int,
    seed: int,
    initial_checkpoint_dir: Optional[Path],
    resume: Union[bool, Literal["auto"], Path],
    config: Config,
    data: DataModule,
    out_dir: Path,
    tokenizer_dir: Optional[Path],
    tokenizer: Optional[Tokenizer],
    train: TrainArgs,
    eval: EvalArgs,
    optimizer: Union[str, Dict],
    compiler: Optional[Literal["thunder", "torch"]],
    num_nodes: int = 1,
) -> None:
    """Build the model, optimizer and dataloaders, run ``fit`` and save the final checkpoint.

    If ``initial_checkpoint_dir`` is set, its ``lit_model.pth`` weights are loaded into the model. If a resume path is
    found (via ``find_resume_path``), the full training ``state`` is restored from it before training starts. The
    final checkpoint is written to ``out_dir / "final" / "lit_model.pth"``.
    """
    validate_args(train, eval, initial_checkpoint_dir, resume)

    if fabric.global_rank == 0:
        out_dir.mkdir(parents=True, exist_ok=True)

    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)

    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True):
        model = GPT(config)

    initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)

    if train.tie_embeddings:
        # share a single weight tensor between the token embedding and the output head
        model.transformer.wte.weight = model.lm_head.weight
    if train.max_seq_length:
        model.max_seq_length = train.max_seq_length

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters: {num_parameters(model):,}")

    model = fabric.setup(model)
    if compiler == "thunder":
        # avoid `Tensor.register_hook` which is unsupported
        # (backward-hook registration on the Fabric-wrapped module is replaced by a no-op)
        model._register_backward_hook = lambda *_: None

    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    if initial_checkpoint_dir:
        fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model)

    # everything needed to resume training; saved by `save_checkpoint` and restored by `fabric.load`
    state = {
        "model": model,
        "optimizer": optimizer,
        "train_dataloader": train_dataloader,
        "iter_num": 0,
        "step_count": 0,
    }

    resume = find_resume_path(resume, out_dir)
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    fit(
        fabric=fabric,
        devices=devices,
        num_nodes=num_nodes,
        state=state,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        out_dir=out_dir,
        tokenizer_dir=tokenizer_dir,
        train=train,
        eval=eval,
        optimizer=optimizer,
    )
    fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")

    # Save final checkpoint
    save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")

    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def fit(
    fabric: L.Fabric,
    devices: int,
    state: dict,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    out_dir: Path,
    tokenizer_dir: Optional[Path],
    train: TrainArgs,
    eval: EvalArgs,
    optimizer: Union[str, Dict],
    num_nodes: int = 1,
) -> None:
    """Run the training loop, logging, validating and checkpointing periodically.

    Training stops after ``max_iters`` micro-batch iterations, derived from ``train.max_tokens`` split across
    ``fabric.world_size``. ``state["iter_num"]`` and ``state["step_count"]`` are updated in place.

    Note: the ``optimizer`` argument is not used directly; the optimizer is read from ``state["optimizer"]``.
    """
    model = state["model"]
    optimizer = state["optimizer"]

    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check
    throughput = ThroughputMonitor(fabric, window_size=5)

    # Measure the FLOPs of one forward+loss on a "meta"-device copy of the model; only shapes matter here.
    with torch.device("meta"):
        meta_model = GPT(model.config)
        # dummy token ids (all zeros, since `high` is exclusive)
        x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))
        model_fwd = lambda: meta_model(x)  # noqa: F821
        model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0)  # noqa: F821
        measured_flops = measure_flops(meta_model, model_fwd, model_loss)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    # `max_iters` counts micro-batch iterations per device, not optimizer steps
    max_tokens_per_device = train.max_tokens // fabric.world_size
    tokens_per_iter = train.micro_batch_size * model.max_seq_length
    max_iters = max_tokens_per_device // tokens_per_iter
    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
    initial_iter = state["iter_num"]
    train_iterator = CycleIterator(train_dataloader)

    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
        fabric.device
    )
    fabric.barrier()
    total_t0 = time.perf_counter()
    # shown as "n/a" until the first validation run; afterwards a float (formatted to a string when logged)
    val_loss = "n/a"

    warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)

    for train_data in train_iterator:
        if state["iter_num"] >= max_iters:
            break

        # determine and set the learning rate for this iteration
        lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        state["iter_num"] += 1
        iter_t0 = time.perf_counter()

        # targets are the inputs shifted by one position
        input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
        targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()

        # Skip gradient synchronization on accumulation iterations; the optimizer only steps once every
        # `gradient_accumulation_iters` iterations, and the loss is divided by that count before backward.
        is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # module-level name: may be the compiled version installed by `setup`
            loss = forward_and_loss(model, input_ids, targets)
            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

        running_loss.update(loss.detach())

        if not is_accumulating:
            # THUNDER unsupported: https://github.com/Lightning-AI/lightning-thunder/issues/2357
            # fabric.clip_gradients(model, optimizer, max_norm=train.max_norm)
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1

        if state["iter_num"] % log_iter_interval == 0:
            loss = running_loss.compute().item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=(t1 - total_t0),
                flops=(measured_flops * log_iter_interval),
                batches=state["iter_num"],
                samples=(state["iter_num"] * train.micro_batch_size),
                lengths=(state["iter_num"] * train.micro_batch_size * model.max_seq_length),
            )
            metrics = {
                "loss": loss,
                "iter": state["iter_num"],
                "step": state["step_count"],
                "epoch": train_iterator.epoch,
                "iter_time": t1 - iter_t0,
                # linear extrapolation from the average time per iteration since this run started
                "remaining_time": (
                    (t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
                ),
                "tokens": state["iter_num"] * train.micro_batch_size * model.max_seq_length,
                "total_tokens": (state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size),
                "learning_rate": lr,
            }
            if isinstance(val_loss, float):
                val_loss = f"{val_loss:.3f}"
            fabric.print(
                f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
                f" loss train: {metrics['loss']:.3f},"
                f" val: {val_loss} |"
                f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
                f"{' (step)' if not is_accumulating else ''}"
                f" remaining time: {timedelta(seconds=int(metrics['remaining_time']))!s}"
            )

            throughput_metrics = throughput.compute()
            metrics.update(throughput_metrics)
            fabric.log_dict(metrics, step=state["iter_num"] - 1)

        # validation and checkpointing only happen right after an optimizer step
        if val_dataloader is not None and not is_accumulating and state["step_count"] % eval.interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
            val_loss = val_loss.item()
            td = time.perf_counter() - t0

            fabric.print(f"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms")
            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
            fabric.log_dict(metrics, step=state["iter_num"] - 1)
            fabric.barrier()

        if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
            save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")


@torch.no_grad() def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor: fabric.barrier() fabric.print("Validating ...") model.eval() losses = [] for k, 
batch in enumerate(val_dataloader): if k >= max_iters: break input_ids = batch[:, 0 : model.max_seq_length].contiguous().long() targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long() loss = forward_and_loss(model, input_ids, targets) losses.append(loss) val_loss = torch.stack(losses).mean() model.train() fabric.barrier() return val_loss def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() return train_dataloader, val_dataloader # learning rate decay scheduler (cosine with linear warmup) def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float: # 1) linear warmup for warmup_iters steps if it < warmup_iters: return learning_rate * it / warmup_iters # 2) if it > max_iters, return min learning rate if it > max_iters: return min_lr # 3) in between, use cosine decay down to min learning rate decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 return min_lr + coeff * (learning_rate - min_lr) def initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None: """GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745).""" # Adapted from https://github.com/jzhang38/TinyLlama def init_weights(module, std): nn.init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: nn.init.zeros_(module.bias) for mod in model.modules(): if isinstance(mod, (nn.Embedding, nn.Linear)): mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd)) # need a separate loop because `mod.proj` below is a 
`nn.Linear` too for mod in model.modules(): if isinstance(mod, (LLaMAMLP, CausalSelfAttention, MultiheadLatentAttention)): mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer)) if not isinstance(fabric.strategy, FSDPStrategy): reset_parameters(model) def init_out_dir(out_dir: Path) -> Path: if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ: return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir return out_dir def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file): model = state["model"] checkpoint_file.parent.mkdir(parents=True, exist_ok=True) fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}") fabric.save(checkpoint_file, state) if fabric.global_rank == 0: save_hyperparameters(setup, checkpoint_file.parent) if tokenizer_dir is not None: copy_config_files(tokenizer_dir, checkpoint_file.parent) save_config(model.config, checkpoint_file.parent) def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None: issues = [] unsupported = [(train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [(train, ["max_tokens", "max_norm"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if initial_checkpoint_dir and resume: issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. 
Choose one.") if issues: raise ValueError("\n".join(issues)) def jit(fn: Callable, executors: List[str]) -> Any: assert executors is not None from unsloth.executor import unsloth_ex # import for registration # noqa: F401 import thunder return thunder.jit(fn, executors=executors) if __name__ == "__main__": torch.set_float32_matmul_precision("high") CLI(setup) ================================================ FILE: extensions/thunder/strategies/__init__.py ================================================ from .thunder_ddp import ThunderDDPStrategy # noqa: F401 from .thunder_fsdp import ThunderFSDPStrategy # noqa: F401 ================================================ FILE: extensions/thunder/strategies/thunder_ddp.py ================================================ """Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually.""" from contextlib import nullcontext from datetime import timedelta from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple, Union import torch import torch.distributed from lightning.fabric.accelerators.accelerator import Accelerator from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO from lightning.fabric.plugins.precision import Precision from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.fabric.strategies.parallel import ParallelStrategy from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl from lightning.fabric.utilities.distributed import ( ReduceOp, _distributed_is_initialized, _get_default_process_group_backend_for_device, _init_dist_connection, _sync_ddp_if_available, ) from lightning.fabric.utilities.rank_zero import rank_zero_only from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only 
from torch import Tensor from torch.nn import Module from typing_extensions import override from litgpt.constants import _THUNDER_AVAILABLE if TYPE_CHECKING: from thunder import Executor class ThunderDDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, jit: bool = True, executors: Optional[Tuple[Union["Executor", str], ...]] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, **kwargs: Any, ): r"""Strategy for Replicated Data Parallel provided by Lightning Thunder. .. warning:: This is an :ref:`experimental ` feature. Arguments: jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually jitting a function that includes the model. executors: The list of Thunder executors to enable. They can be either string aliases for the executors or the actual executor instances. \**kwargs: See available parameters in :func:`thunder.distributed.ddp`. 
""" if not _THUNDER_AVAILABLE: raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) self.parallel_devices = parallel_devices self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment if not jit and executors is not None: raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") self.jit = jit self.executors = executors self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _ThunderDataParalellBackwardSyncControl() self._ddp_kwargs = kwargs @property @override def root_device(self) -> torch.device: assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] @property def num_nodes(self) -> int: return self._num_nodes @num_nodes.setter def num_nodes(self, num_nodes: int) -> None: # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks self._num_nodes = num_nodes @property def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property @override def distributed_sampler_kwargs(self) -> Dict[str, Any]: return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} @override def _configure_launcher(self) -> None: assert self.cluster_environment is not None if not self.cluster_environment.creates_processes_externally: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) @property def process_group_backend(self) -> Optional[str]: return self._process_group_backend @override def _configure_launcher(self) -> None: assert self.cluster_environment is not None self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) @override def setup_environment(self) -> None: super().setup_environment() 
self._setup_distributed() @override def setup_module(self, module: Module) -> Module: import thunder if (cd := thunder.compile_data(module)) is not None: # the module was already jitted if thunder.compile_stats(module).last_traces is not None: raise RuntimeError( "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" " DDP transform. Remove the `forward` call before `fabric.setup()`" ) assert cd.is_module # sanity check ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs) # update the compile data state cd.fn = ddp_module cd.process_group_for_ddp = ddp_module.process_group_for_ddp return module else: module = thunder.distributed.ddp(module, **self._ddp_kwargs) if not self.jit: return module return thunder.jit(module, executors=self.executors) @override def module_to_device(self, module: Module) -> None: module.to(self.root_device) @override def all_reduce( self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" ) -> Tensor: if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor @override def barrier(self, *args: Any, **kwargs: Any) -> None: if not _distributed_is_initialized(): return if torch.distributed.get_backend() == "nccl": torch.distributed.barrier(device_ids=[self.root_device.index]) else: torch.distributed.barrier() @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: if not _distributed_is_initialized(): return obj obj = [obj] torch.distributed.broadcast_object_list(obj, src) return obj[0] def _setup_distributed(self) -> None: self._set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) def _get_process_group_backend(self) -> str: return self._process_group_backend or 
_get_default_process_group_backend_for_device(self.root_device) def _set_world_ranks(self) -> None: if self.cluster_environment is not None: self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank class _ThunderDataParalellBackwardSyncControl(_BackwardSyncControl): def __init__(self): self._enabled = False @override def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """ In Thunder, we cannot use ``module.no_sync()`` because reduction happens at the end of the context manager. It assumes that the user will reuse it across all gradient accumulation iterations: .. code-block:: python with model.no_sync(): for _ in range(len(gradient_accumulation_iters)): fwd() bwd() # uses no-sync-backward trace fwd() bwd() # uses regular-backward trace However, Fabric is designed to the context manager every iteration: .. code-block:: python for i in range(iters): is_accumulating = (i + 1) % gradient_accumulation_iters != 0 ctx = model.no_sync() if is_accumulating else nullcontext() with ctx: fwd() bwd() So we need to be smart about when to sync grads based on the ``enabled`` value. More info in https://github.com/Lightning-AI/lit-thunder-LEGACY/issues/2085 """ if not getattr(module, "use_ddp", False) and not getattr(module, "use_fsdp", False): raise TypeError( "Blocking backward sync is only possible if the module passed to" f" `{self.__class__.__name__}.no_backward_sync` is applied DDP or FSDP." f" Got: {module.__class__.__name__}." 
) from thunder.distributed import skip_data_parallel_grad_sync previous, self._enabled = self._enabled, enabled if enabled: return skip_data_parallel_grad_sync() if not enabled and previous: return _SyncGradsContextManager(module) return nullcontext() class _SyncGradsContextManager: def __init__(self, module: Module) -> None: self._module = module @override def __enter__(self) -> None: from thunder.distributed import _sync_grads _sync_grads(self._module) @override def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: pass ================================================ FILE: extensions/thunder/strategies/thunder_fsdp.py ================================================ """Fabric Strategy to support Thunder FSDP: To be upstreamed into Fabric eventually.""" import shutil from contextlib import ExitStack, nullcontext from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Tuple, Union import torch from lightning.fabric.accelerators.accelerator import Accelerator from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO from lightning.fabric.plugins.precision import Precision from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.fabric.strategies.parallel import ParallelStrategy from lightning.fabric.strategies.strategy import TBroadcast, _apply_filter, _Sharded, _validate_keys_for_strict_loading from lightning.fabric.utilities.distributed import ( ReduceOp, _distributed_is_initialized, _get_default_process_group_backend_for_device, _init_dist_connection, _sync_ddp_if_available, ) from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.load import _METADATA_FILENAME, _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_only from lightning.fabric.utilities.seed 
import reset_seed from lightning.fabric.utilities.types import _PATH, _Stateful from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only from torch import Tensor from torch.nn import Module from torch.optim import Optimizer from typing_extensions import override from extensions.thunder.strategies.thunder_ddp import _ThunderDataParalellBackwardSyncControl from litgpt.constants import _THUNDER_AVAILABLE if TYPE_CHECKING: from thunder import Executor from thunder.distributed import FSDPBucketingStrategy, FSDPType from thunder.distributed.checkpoint import StateDictOptions _FSDP_TYPE = Union[FSDPType, Literal["ZERO2", "ZERO3"]] _BUCKETING_STRATEGY = Union[FSDPBucketingStrategy, Literal["NONE", "LAYER", "BLOCK"]] class ThunderFSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, jit: bool = True, executors: Optional[Tuple[Union["Executor", str], ...]] = None, sharding_strategy: "_FSDP_TYPE" = "ZERO3", bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE", state_dict_type: Literal["full", "sharded"] = "sharded", **kwargs: Any, ): r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder. .. warning:: This is an :ref:`experimental ` feature. Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP, whilst scaling our model sizes dramatically. Arguments: jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually jitting a function that includes the model. executors: The list of Thunder executors to enable. 
They can be either string aliases for the executors or the actual executor instances. sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of them: - ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default). - ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated. Also accepts a :class:`thunder.distributed.FSDPType` enum value. bucketing_strategy: Enables combining the collective operations for sets of layers. - ``"NONE"``: No bucketing (default). - ``"LAYER"``: Create buckets per layer class. - ``"BLOCK"``: Create buckets per layer block. Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value. state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file (default). - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size. \**kwargs: See available parameters in :func:`thunder.distributed.fsdp`. 
""" if not _TORCH_GREATER_EQUAL_2_2: raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.") if not _THUNDER_AVAILABLE: raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) self.parallel_devices = parallel_devices self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment from thunder.distributed import FSDPBucketingStrategy, FSDPType self.sharding_strategy = ( FSDPType[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy ) self.bucketing_strategy = ( FSDPBucketingStrategy[bucketing_strategy.upper()] if isinstance(bucketing_strategy, str) else bucketing_strategy ) if not jit and executors is not None: raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") self.jit = jit self.executors = executors self._state_dict_type = state_dict_type self._backward_sync_control = _ThunderDataParalellBackwardSyncControl() self._fsdp_kwargs = kwargs @property @override def root_device(self) -> torch.device: assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] @property def num_nodes(self) -> int: return 1 @property def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property @override def distributed_sampler_kwargs(self) -> Dict[str, Any]: return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} @override def _configure_launcher(self) -> None: assert self.cluster_environment is not None if not self.cluster_environment.creates_processes_externally: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) @override def setup_environment(self) -> None: super().setup_environment() self._setup_distributed() @override def setup_module(self, module: Module) -> Module: import thunder if (cd := thunder.compile_data(module)) is 
not None: # the module was already jitted if thunder.compile_stats(module).last_traces is not None: raise RuntimeError( "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" " FSDP transform. Remove the `forward` call before `fabric.setup()`" ) assert cd.is_module # sanity check fsdp_module = thunder.distributed.fsdp( cd.fn, device=self.root_device, sharding_strategy=self.sharding_strategy, bucketing_strategy=self.bucketing_strategy, **self._fsdp_kwargs, ) # update the compile data state cd.fn = fsdp_module cd.process_group_for_ddp = fsdp_module.process_group_for_ddp return module else: module = thunder.distributed.fsdp( module, device=self.root_device, sharding_strategy=self.sharding_strategy, bucketing_strategy=self.bucketing_strategy, **self._fsdp_kwargs, ) if not self.jit: return module return thunder.jit(module, executors=self.executors) @override def module_to_device(self, module: Module) -> None: pass @override def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() if empty_init: # Materialization happens in `setup`. 
When modules get wrapped by FSDP stack.enter_context(torch.device("meta")) stack.enter_context(precision_init_ctx) stack.enter_context(module_sharded_ctx) return stack @override def module_sharded_context(self) -> ContextManager: return nullcontext() @override def all_reduce( self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" ) -> Tensor: if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor @override def barrier(self, *args: Any, **kwargs: Any) -> None: if not _distributed_is_initialized(): return if torch.distributed.get_backend() == "nccl": torch.distributed.barrier(device_ids=[self.root_device.index]) else: torch.distributed.barrier() @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: if not _distributed_is_initialized(): return obj obj = [obj] torch.distributed.broadcast_object_list(obj, src) return obj[0] @override def clip_gradients_norm( self, module: Module, optimizer: Optimizer, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0, error_if_nonfinite: bool = True, ) -> Tensor: raise NotImplementedError @override def save_checkpoint( self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: if storage_options is not None: raise TypeError( "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" " `FSDPStrategy` does not use the `CheckpointIO`." 
) if filter is not None: raise NotImplementedError("Filtering checkpoint paths is not implemented") # broadcast the path from rank 0 to ensure all the states are saved in a common path path = Path(self.broadcast(path)) if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, save modules = [module for module in state.values() if has_fsdp_modules(module)] if len(modules) == 0: raise ValueError( "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." ) if len(modules) > 1: raise ValueError( "Found multiple FSDP models in the given state. Saving checkpoints with FSDP is" " currently limited to a single model per checkpoint. To save multiple models, call the" " save method for each model separately with a different path." 
) if self._state_dict_type == "sharded": if _is_full_checkpoint(path): path.unlink() path.mkdir(parents=True, exist_ok=True) options = StateDictOptions(full_state_dict=False, cpu_offload=True, rank0_only=False) converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) save(converted_state, path) if self.global_rank == 0: torch.save(metadata, path / _METADATA_FILENAME) elif self._state_dict_type == "full": if _is_sharded_checkpoint(path): shutil.rmtree(path) options = StateDictOptions(full_state_dict=True, cpu_offload=True, rank0_only=True) converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) converted_state.update(metadata) if self.global_rank == 0: torch.save(converted_state, path) else: raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override def load_checkpoint( self, path: _PATH, state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, ) -> Dict[str, Any]: if not state: raise ValueError( f"Got `FSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least" " a model instance to reload is required. Pass it in like so:" " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`" ) # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(path)) from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, load, load_model_state_dict if isinstance(state, Module): if not _is_full_checkpoint(path): raise ValueError( "Failed to load checkpoint directly into the model. 
The given path must be a single file" f" containing the full state dict: {path}" ) state_dict = torch.load(str(path), mmap=True, map_location="cpu") options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) load_model_state_dict(state_dict, _unwrap_tom(state), options, self.local_rank) return {} if isinstance(state, Optimizer): raise NotImplementedError( "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy." ) modules = {key: module for key, module in state.items() if has_fsdp_modules(module)} if len(modules) == 0: raise ValueError( "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure" " you set up the model (and optimizers if any) through the strategy before loading the checkpoint." ) if len(modules) > 1: raise ValueError( "Found multiple FSDP models in the given state. Loading checkpoints with FSDP is" " currently limited to a single model per checkpoint. To load multiple models, call the" " load method for each model separately with a different path." 
) optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)} module_key, module = list(modules.items())[0] module = _unwrap_tom(module) if _is_sharded_checkpoint(path): options = StateDictOptions(full_state_dict=False, cpu_offload=True, strict=strict, rank0_only=False) # Load the DCP state dict, which requires a holder state dict converted_state, _ = _get_state_dict(state, None, options, self.local_rank) load(converted_state, path) load_model_state_dict(converted_state[module_key], module, options, self.local_rank) # Load metadata (anything not a module or optimizer) metadata = torch.load(path / _METADATA_FILENAME) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) for key in requested_metadata_keys: if key not in metadata: continue state[key] = metadata.pop(key) # return the remaining metadata that wasn't requested as part of `state` return metadata if _is_full_checkpoint(path): options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) if not options.rank0_only or self.local_rank == 0: map_location = "cpu" if options.cpu_offload else None checkpoint = torch.load(str(path), mmap=True, map_location=map_location) load_model_state_dict(checkpoint[module_key], module, options, self.local_rank) else: checkpoint = {} requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) # Load metadata (anything not a module or optimizer) _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys) # return the remaining metadata that wasn't requested as part of `state` return checkpoint raise ValueError( f"The path {str(path)!r} does not point to a valid checkpoint. 
Make sure the path points to either a" " directory with FSDP checkpoint shards, or a single file with a full checkpoint." ) def _setup_distributed(self) -> None: reset_seed() self._set_world_ranks() process_group_backend = _get_default_process_group_backend_for_device(self.root_device) assert self.cluster_environment is not None _init_dist_connection(self.cluster_environment, process_group_backend) def _set_world_ranks(self) -> None: if self.cluster_environment is not None: self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank def _is_sharded_checkpoint(path: Path) -> bool: """A heuristic check to determine whether the path points to a directory with checkpoint shards.""" return path.is_dir() and (path / _METADATA_FILENAME).is_file() def _is_full_checkpoint(path: Path) -> bool: return path.is_file() def _get_state_dict( state: Dict[str, Any], filter: Optional[Dict[str, Callable[[str, Any], bool]]], options: "StateDictOptions", rank: int, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: from thunder.distributed.checkpoint import get_model_state_dict # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata converted_state: Dict[str, Any] = {} metadata: Dict[str, Any] = {} for key, obj in state.items(): converted: Any if isinstance(obj, Module): converted = get_model_state_dict(_unwrap_tom(obj), options, rank) target_dict = converted_state elif isinstance(obj, Optimizer): # TODO: optimizer support converted = obj.state_dict() target_dict = converted_state else: # everything not a module or optimizer is considered metadata 
converted = obj.state_dict() if isinstance(obj, _Stateful) else obj target_dict = metadata _apply_filter(key, filter or {}, converted, target_dict) return converted_state, metadata def _unwrap_tom(obj: object) -> object: # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder from thunder import ThunderModule if isinstance(obj, ThunderModule): return obj._model return obj ================================================ FILE: extensions/thunder/unsloth/__init__.py ================================================ ================================================ FILE: extensions/thunder/unsloth/executor.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys from pathlib import Path from typing import Optional, Tuple import torch from torch import Tensor import litgpt.model from litgpt.constants import _THUNDER_AVAILABLE from litgpt.model import LLaMAMLP as OriginalLLaMAMLP from thunder.core.proxies import TensorProxy from thunder.core.transforms import get_grad, mean_backward, put_grads from thunder.extend import OperatorExecutor, register_executor from thunder.torch import ne, sum, true_divide if _THUNDER_AVAILABLE: import thunder import thunder.torch as ltorch sys.path.append(str(Path(__file__).parent)) import kernels unsloth_ex = OperatorExecutor("unsloth", version="0.1") register_executor(unsloth_ex) """ ==================== Cross Entropy Loss ==================== """ def unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tuple[TensorProxy, TensorProxy]: return ( TensorProxy( shape=(logits.shape[0],), # the cross entropy kernel only supports float32 dtype=thunder.dtypes.float32, device=logits.device, requires_grad=logits.requires_grad, ), TensorProxy(shape=(logits.shape[0],), dtype=thunder.dtypes.float32, device=logits.device, requires_grad=False), ) unsloth_cross_entropy = unsloth_ex.register_operator( 
"unsloth_cross_entropy", meta=unsloth_cross_entropy_meta, fn=kernels.cross_entropy_loss._cross_entropy_forward_impl ) def unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor: # clone() because the kernel writes the grads in the logits return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels) def unsloth_cross_entropy_backward_meta( dlosses: TensorProxy, logits: TensorProxy, logsumexp: TensorProxy, labels: TensorProxy ) -> TensorProxy: return thunder.TensorProxy(like=logits) unsloth_cross_entropy_backward = unsloth_ex.register_operator( "unsloth_cross_entropy_backward", meta=unsloth_cross_entropy_backward_meta, fn=unsloth_cross_entropy_backward_impl ) def unsloth_cross_entropy_checker( logits: TensorProxy, labels: TensorProxy, weight: Optional[TensorProxy] = None, size_average: Optional[bool] = None, ignore_index: int = -100, reduce: Optional[bool] = None, reduction: str = "mean", label_smoothing: float = 0.0, ) -> bool: return ( weight is None and size_average is None and reduce is None and reduction in ("none", "mean") and ignore_index == -100 and label_smoothing == 0.0 and logits.device.type == "cuda" and labels.device.type == "cuda" ) def cross_entropy_to_unsloth( logits: TensorProxy, labels: TensorProxy, weight: Optional[TensorProxy] = None, size_average: Optional[bool] = None, ignore_index: int = -100, reduce: Optional[bool] = None, reduction: str = "mean", label_smoothing: float = 0.0, ) -> Tuple[TensorProxy, TensorProxy]: loss, logsumexp = unsloth_cross_entropy(logits, labels) if reduction == "mean": # "mean" reduction is not part of the kernel # TODO: this doesn't consider that all elements could be masked, causing a division by 0 n_items = sum(ne(labels, -100)) loss = true_divide(sum(loss), n_items) elif reduction != "none": raise NotImplementedError(reduction) return loss, logsumexp def unsloth_cross_entropy_grad( logits: TensorProxy, labels: 
TensorProxy, weight: Optional[TensorProxy] = None, size_average: Optional[bool] = None, ignore_index: int = -100, reduce: Optional[bool] = None, reduction: str = "mean", label_smoothing: float = 0.0, ) -> TensorProxy: loss, logsumexp = cross_entropy_to_unsloth(**locals()) grad = get_grad(loss) if reduction == "mean": grad = mean_backward(logsumexp.ndim, logsumexp.shape, (0,), grad) logits_grad = unsloth_cross_entropy_backward(grad, logits, labels, logsumexp) put_grads((logits,), (logits_grad,)) return loss # registers as cross entropy implementation, including the execution transform and now a grad transform unsloth_ex.register_implementation( ltorch.cross_entropy, checker=unsloth_cross_entropy_checker, execution_transform=lambda *args: cross_entropy_to_unsloth(*args)[0], grad_transform=unsloth_cross_entropy_grad, ) """ ========= RMSNorm ========= The RMSNorm kernel is not integrated because it's not numerically equal and it doesn't compute the gradient for the weight, just for the input. 
""" """ ======== SwiGLU ======== """ def swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor: return torch.nn.functional.silu(e) * g class ThunderLLaMAMLP(OriginalLLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = swiglu(x_fc_1, x_fc_2) return self.proj(x) litgpt.model.LLaMAMLP = ThunderLLaMAMLP def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy: return TensorProxy(like=e) litgpt_swiglu = unsloth_ex.register_operator("litgpt_swiglu", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu) unsloth_swiglu_forward = unsloth_ex.register_operator( "unsloth_swiglu_forward", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args) ) def unsloth_swiglu_backward_meta(DW: TensorProxy, e: TensorProxy, g: TensorProxy) -> Tuple[TensorProxy, TensorProxy]: return TensorProxy(like=g), TensorProxy(like=e) def unsloth_swiglu_backward_fn(DW: Tensor, e: Tensor, g: Tensor) -> Tuple[Tensor, Tuple]: B, T, n_embd = e.shape e = e.view(-1, n_embd) g = g.view(-1, n_embd) DW, e, g = kernels.swiglu_DWf_DW_dfg_kernel(DW, e, g) e = e.view(B, T, n_embd) g = g.view(B, T, n_embd) return g, e unsloth_swiglu_backward = unsloth_ex.register_operator( "unsloth_swiglu_backward", meta=unsloth_swiglu_backward_meta, fn=unsloth_swiglu_backward_fn ) def swiglu_to_unsloth_checker(e: TensorProxy, g: TensorProxy) -> bool: return e.device.type == "cuda" and g.device.type == "cuda" def unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy: h = unsloth_swiglu_forward(**locals()) grad = get_grad(h) e_grad, g_grad = unsloth_swiglu_backward(grad, e, g) put_grads((e, g), (e_grad, g_grad)) return h unsloth_ex.register_implementation( litgpt_swiglu, checker=swiglu_to_unsloth_checker, execution_transform=unsloth_swiglu_forward, grad_transform=unsloth_swiglu_grad, ) """ ====== RoPE ====== """ def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy: return TensorProxy(like=x) 
apply_rope = unsloth_ex.register_operator( "litgpt_apply_rope", like=apply_rope_meta, fn=litgpt.model.apply_rope, replaces=litgpt.model.apply_rope ) def unsloth_apply_rope_meta( Q: TensorProxy, cos: TensorProxy, sin: TensorProxy ) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]: batch, n_heads, seq_len, head_dim = Q.shape assert seq_len <= cos.shape[-2] BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2) div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE) n_groups = div + (mod != 0) return TensorProxy(like=Q), cos, sin, n_groups, BLOCK_SIZE, num_warps unsloth_apply_rope = unsloth_ex.register_operator( "unsloth_apply_rope", meta=unsloth_apply_rope_meta, fn=kernels._rope_embedding_forward_impl ) def unsloth_apply_rope_backward_meta( dY: TensorProxy, cos: TensorProxy, sin: TensorProxy, n_groups: int, BLOCK_SIZE: int, num_warps: int ) -> TensorProxy: return TensorProxy(like=dY) unsloth_apply_rope_backward = unsloth_ex.register_operator( "unsloth_apply_rope_backward", meta=unsloth_apply_rope_backward_meta, fn=kernels._rope_embedding_backward_impl ) def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool: return len(x.shape) == 4 and x.device.type == "cuda" and cos.device.type == "cuda" and sin.device.type == "cuda" def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy: Q, cos, sin, n_groups, BLOCK_SIZE, num_warps = unsloth_apply_rope(x, cos, sin) dY = get_grad(Q) dX = unsloth_apply_rope_backward(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps) put_grads((x,), (dX,)) return Q unsloth_ex.register_implementation( apply_rope, checker=apply_rope_to_unsloth_checker, execution_transform=lambda *args: unsloth_apply_rope(*args)[0], grad_transform=unsloth_apply_rope_grad, ) ================================================ FILE: extensions/thunder/unsloth/kernels/__init__.py ================================================ from .cross_entropy_loss import 
_cross_entropy_backward_impl, _cross_entropy_forward_impl # noqa: F401 from .rope_embedding import ROPE_GROUP_SIZE, _rope_embedding_backward_impl, _rope_embedding_forward_impl # noqa: F401 from .swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel # noqa: F401 from .utils import calculate_settings # noqa: F401 ================================================ FILE: extensions/thunder/unsloth/kernels/cross_entropy_loss.py ================================================ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from litgpt.constants import _TRITON_AVAILABLE from .utils import MAX_FUSED_SIZE, calculate_settings if _TRITON_AVAILABLE: import triton import triton.language as tl @triton.jit def _cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] Pi = exp(xi) / sum(exp(xi)) CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ] = -y [ x - log[sum(exp(x))] ] = y * (log[sum(exp(x))] - x) If y == 0: CE_i = 0 If y == 1: CE_i = logsumexp - x logsumexp is also stable Take y = log[sum(exp(x))] exp(y) = sum(exp(x)) exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x exp(y) = exp(c)*sum(exp(x - c)) y = log(exp(c)*sum(exp(x - c))) y = c + log[sum(exp(x - c))] This means we can set c = max(x) to make sure exp(x - c) always is exp(x - max(x)). 
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) logits_ptr += row_idx * logits_row_stride.to(tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) if label_idx != -100: x = tl.load(logits_ptr + label_idx).to(tl.float32) loss = logsumexp - x else: loss = 0.0 tl.store(logsumexp_ptr, logsumexp) tl.store(loss_ptr, loss) pass @triton.jit def _chunked_cross_entropy_forward( logits_ptr, logits_row_stride, loss_ptr, logsumexp_ptr, labels_ptr, VOCAB_SIZE: tl.constexpr, N_CHUNKS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ 256K vocab divided in 4 chunks |-65536-| |-65536-| |-65536-| |-65536-| |-------| |-------| |-------| |-------| |-------| |-------| |-------| |-------| If y == 0: CE_i = 0 If y == 1: CE_i = logsumexp - x Notice we can do logsumexp for each chunk and then logsumexp[chunk_sum(logsumexp)] == logsumexp chunk_sum = log[chunk_sum(logsumexp)] = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))] = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])] = log[sum(exp(a)) + ... + sum(exp(z))] = logsumexp(x) This means we can perform a logsumexp for each chunk, then do a final logsumexp reduction! 
Ie do: logsumexp(chunked_logsumexp) - x """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) logits_ptr += row_idx * logits_row_stride.to(tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) if chunk_idx == 0: # logsumexp(chunked_logsumexp) - x # Do the -x separately if label_idx != -100: x = tl.load(logits_ptr + label_idx).to(tl.float32) loss = -1.0 * x else: loss = 0.0 tl.store(loss_ptr, loss) pass tl.store(logsumexp_ptr, logsumexp) pass @triton.jit def _cross_entropy_backward( logits_ptr, logits_row_stride, dloss_ptr, dloss_row_stride, logsumexp_ptr, labels_ptr, VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) dC/dx = d/dx (y * log[sum(exp(x))] - x * y) From https://en.wikipedia.org/wiki/LogSumExp d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x) dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y) dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick dC/dx = y * exp[x - logsumexp] - d/dx (x * y) If y == 0: dC/dx = 0 If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1 If y == 1 and x != label: dC/dx = exp[x - logsumexp] """ row_idx = tl.program_id(0) block_idx = tl.program_id(1) logits_ptr += row_idx * logits_row_stride.to(tl.int64) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) if label_idx != -100: dloss = tl.load(dloss_ptr) else: dloss = 0.0 x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) logsumexp = tl.load(logsumexp_ptr + row_idx) y = tl.exp(x - 
logsumexp) y = tl.where( col_offsets == label_idx, y - 1.0, # exp(x - logsumexp) - 1 y, # exp(x - logsumexp) ) # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0. tl.store(logits_ptr + col_offsets, dloss * y, mask=mask) pass def _cross_entropy_forward_impl(logits, labels): n_rows, vocab_size = logits.shape div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype=torch.float32, device="cuda") if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) logsumexp = torch.empty(n_rows, dtype=torch.float32, device="cuda") _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), losses, logsumexp, labels, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) else: # For large vocabs > 65336 like Gemma 256K logsumexp = torch.empty( ( n_rows, n_chunks, ), dtype=torch.float32, device="cuda", ) _chunked_cross_entropy_forward[ ( n_rows, n_chunks, ) ]( logits, logits.stride(0), losses, logsumexp, labels, VOCAB_SIZE=vocab_size, N_CHUNKS=n_chunks, BLOCK_SIZE=MAX_FUSED_SIZE, num_warps=32, ) # logsumexp(chunked_logsumexp) - x # Do the -x separately logsumexp = torch.logsumexp(logsumexp, dim=1) # Row sum losses += logsumexp losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out! return losses, logsumexp def _cross_entropy_backward_impl(dlosses, logits, logsumexp, labels): n_rows, vocab_size = logits.shape BLOCK_SIZE = 4096 div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks = div + (mod != 0) _cross_entropy_backward[ ( n_rows, n_blocks, ) ]( logits, logits.stride(0), dlosses, dlosses.stride(0), logsumexp, labels, VOCAB_SIZE=vocab_size, BLOCK_SIZE=BLOCK_SIZE, num_warps=8, ) return logits ================================================ FILE: extensions/thunder/unsloth/kernels/rope_embedding.py ================================================ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from litgpt.constants import _TRITON_AVAILABLE from .utils import calculate_settings if _TRITON_AVAILABLE: import triton import triton.language as tl ROPE_GROUP_SIZE = 4 @triton.heuristics( { "BACKWARD_PASS": lambda args: args["BACKWARD_PASS"], } ) @triton.jit def _rope_embedding( Q, Q_row_stride, cos, cos_row_stride, sin, sin_row_stride, seqlen, head_dim: tl.constexpr, n_heads: tl.constexpr, BACKWARD_PASS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ROPE_GROUP_SIZE: tl.constexpr = 4, ): """ Calculates the RoPE Embedding quickly RoPE is Q * cos + rotate_half(Q) * sin See our blog post for more info """ row_position = tl.program_id(0) group_head_position = tl.program_id(1) col_offsets = tl.arange(0, BLOCK_SIZE) half_head_dim = head_dim // 2 mask = col_offsets < half_head_dim sin1 = tl.load(sin + (row_position % seqlen) * sin_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0) cos1 = tl.load(cos + (row_position % seqlen) * cos_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0) if BACKWARD_PASS: # See our blog post for more info.
sin1 = -sin1 pass # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8 head_start = group_head_position * ROPE_GROUP_SIZE head_end = min((head_start + ROPE_GROUP_SIZE), n_heads) # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238) for k in range(head_start, head_end): offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim # For Gemma - sometimes RoPE must be done in float32 and not bfloat16 Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype) Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype) tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask) tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask) pass pass def _rope_embedding_forward_impl(Q, cos, sin): """Apply RoPE to ``Q`` of shape ``(batch, n_heads, seq_len, head_dim)`` with the Triton kernel. The kernel works in place on a transposed ``clone()`` of ``Q``, so the input itself is not modified. Returns the rotated tensor in ``Q``'s original layout, the squeezed ``cos``/``sin``, and the ``n_groups``, ``BLOCK_SIZE`` and ``num_warps`` launch settings that ``_rope_embedding_backward_impl`` reuses. NOTE(review): ``cos.squeeze()`` drops every singleton dim, so this assumes ``cos``/``sin`` only carry singleton leading dims -- confirm for batched positions. """ Q = Q.transpose(1, 2).clone() cos, sin = cos.squeeze(), sin.squeeze() batch, seq_len, n_heads, head_dim = Q.shape Q = Q.reshape(batch * seq_len, n_heads * head_dim) n_rows, n_cols = Q.shape assert seq_len <= cos.shape[0] # [TODO] Changing blocksize to head_dim//2 seems to have # some concurrency / un-deterministic issues. BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2) # (head_dim//2) # group_size = 4 # 4 or 8, too large group_size can hurt performance.
div, mod = divmod(n_heads, ROPE_GROUP_SIZE) n_groups = div + (mod != 0) _rope_embedding[ ( n_rows, n_groups, ) ]( Q, Q.stride(0), cos, cos.stride(0), sin, sin.stride(0), seq_len, head_dim, n_heads, BACKWARD_PASS=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) Q = Q.view(batch, seq_len, n_heads, head_dim) Q = Q.transpose(1, 2) return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps def _rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps): """Backward pass of ``_rope_embedding_forward_impl``. Rotates the gradient ``dY`` with ``sin`` negated (``BACKWARD_PASS=True``), reusing the ``cos``/``sin`` and launch settings returned by the forward pass, and returns it in ``dY``'s original layout. NOTE(review): the kernel writes into the reshaped tensor, which aliases ``dY`` whenever ``reshape`` can return a view -- confirm callers do not reuse ``dY``. """ dY = dY.transpose(1, 2) batch, seq_len, n_heads, head_dim = dY.shape dY = dY.reshape(batch * seq_len, n_heads * head_dim) # Must be reshape not view n_rows, n_cols = dY.shape _rope_embedding[ ( n_rows, n_groups, ) ]( dY, dY.stride(0), cos, cos.stride(0), sin, sin.stride(0), seq_len, head_dim, n_heads, BACKWARD_PASS=True, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) dY = dY.view(batch, seq_len, n_heads, head_dim) dY = dY.transpose(1, 2) return dY ================================================ FILE: extensions/thunder/unsloth/kernels/swiglu.py ================================================ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
import torch from litgpt.constants import _TRITON_AVAILABLE if _TRITON_AVAILABLE: import triton import triton.language as tl @triton.jit def _fg_kernel( e, g, h, n_elements, BLOCK_SIZE: tl.constexpr, ): block_idx = tl.program_id(0) offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32) g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32) # f = e * sigmoid(e) f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row)) f_row = f_row.to(g_row.dtype) # Exact copy from HF # h = f * g h_row = f_row * g_row # Store h tl.store(h + offsets, h_row, mask=mask) pass def swiglu_fg_kernel(e, g): batch, seq_len, hd = e.shape n_elements = e.numel() h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device="cuda") grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _fg_kernel[grid]( e, g, h, n_elements, BLOCK_SIZE=1024, ) return h pass @triton.jit def _DWf_DW_dfg_kernel( DW, e, g, n_elements, BLOCK_SIZE: tl.constexpr, ): """ e = e.float() se = 1.0 / (1.0 + torch.exp(-e)) f = (se * e).to(dtype) h = f * g df = DW * f dg = DW * g de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) """ block_idx = tl.program_id(0) offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32) e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32) g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32) # e = e.float() # se = 1.0 / (1.0 + torch.exp(-e)) se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row)) # f = (se * e).to(dtype) f_row = se_row * e_row f_row = f_row.to(DW_row.dtype) # h = f * g h_row = f_row * g_row # df = DW * f df_row = DW_row * f_row # dg = DW * g dg_row = DW_row * g_row # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row)) de_row = 
de_row.to(DW_row.dtype) # Store derivatives in buffers tl.store(DW + offsets, h_row, mask=mask) # h = f * g tl.store(e + offsets, df_row, mask=mask) # df = DW * f tl.store(g + offsets, de_row, mask=mask) # de pass def swiglu_DWf_DW_dfg_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _DWf_DW_dfg_kernel[grid]( DW, e, g, n_elements, BLOCK_SIZE=1024, ) return DW, e, g pass ================================================ FILE: extensions/thunder/unsloth/kernels/utils.py ================================================ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from litgpt.constants import _TRITON_AVAILABLE if _TRITON_AVAILABLE: import triton MAX_FUSED_SIZE = 65536 # 2**16 next_power_of_2 = triton.next_power_of_2 def calculate_settings(n): BLOCK_SIZE = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError( f"Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}." 
) num_warps = 4 if BLOCK_SIZE >= 32768: num_warps = 32 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: num_warps = 8 return BLOCK_SIZE, num_warps pass ================================================ FILE: extensions/xla/README.md ================================================ # TPU support This project utilizes [`Fabric`](https://lightning.ai/docs/fabric/stable), which supports TPUs via [PyTorch XLA](https://github.com/pytorch/xla). > [!NOTE] > This guide assumes that you have already set-up your [Google Cloud environment](https://cloud.google.com/run/docs/setup). To set up a Google Cloud instance with a TPU v4 VM, run the following commands: ```shell gcloud compute tpus tpu-vm create litgpt --version=tpu-vm-v4-base --accelerator-type=v4-8 --zone=us-central2-b gcloud compute tpus tpu-vm ssh litgpt --zone=us-central2-b ``` You can also choose a different TPU type. To do so, change the `version`, `accelerator-type`, and `zone` arguments. Find all regions and zones [here](https://cloud.google.com/tpu/docs/regions-zones).
Multihost caveats TPU v4-8 uses a single host. SSH'ing into the machine and running commands manually will only work when using a single host (1 slice in the TPU pod). In multi-host environments, such as larger TPU pod slices, it's necessary to launch all commands on all hosts simultaneously to avoid hangs. For local development, it is advisable to zip all your current changes, upload the archive to the VM, and run the commands on every host from your personal computer: ```shell # Zip the local directory, excluding large directories from the zip. You may want to keep them. zip -r local_changes.zip . -x ".git/*" "checkpoints/*" "data/*" "out/*" # Copy the .zip file to the TPU VM gcloud compute tpus tpu-vm scp --worker=all local_changes.zip "litgpt:~" # Unzip on each host gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="cd ~; unzip -q -o local_changes.zip" # Example of a typical workflow gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="cd ~; bash install_dependencies.sh" gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="cd ~; bash prepare_checkpoints.sh" gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="cd ~; bash run_desired_script.sh" # This will allow you to kill all python processes on all workers gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="pkill -e python" ``` Note that the commands to install the environment and prepare checkpoints need to be run on all workers, since the filesystem of each worker (host) is not shared. For simplicity, the rest of this tutorial assumes that everything runs on a single host.
Once inside the machine, clone the repository and install the dependencies: ```shell git clone https://github.com/Lightning-AI/litgpt cd litgpt pip install . ``` Install an optimized BLAS library: ```shell sudo apt update sudo apt install libopenblas-dev ``` Since LitGPT requires a torch version newer than torch 2.0.0, manually install nightly builds of torch and torch_xla: ```shell pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl ``` While computations will run by default using the new PjRT runtime, it is recommended to set the following environment variables: ```shell export ALLOW_MULTIPLE_LIBTPU_LOAD=1 export PJRT_DEVICE=TPU ``` > [!NOTE] > An extensive guide on setup and available options can be found [here](https://cloud.google.com/tpu/docs/v4-users-guide). Since a new machine was created, you may need to download pretrained weights. They can be copied to the machine using `gcloud compute tpus tpu-vm scp`, or you can follow the steps described in our [downloading guide](../../tutorials/download_model_weights.md). It is also recommended to set up a persistent disk from which to load checkpoints. Follow [this guide](https://cloud.google.com/tpu/docs/setup-persistent-disk#setting_up_a_tpu_vm_and_a_persistent_disk) to do so. Read-write disks are not supported in multihost VM setups, so persistent disks cannot be used to save checkpoints in that case. Persistent disks can still be useful in read-only mode to load pretrained weights before finetuning or inference. In multihost settings, FSDP saves checkpoint shards on each host, and these shards then need to be consolidated into a single checkpoint. For safekeeping, it is recommended to upload the consolidated checkpoints to a Google Cloud bucket.
Alternatively, you can use the `scp` command to transfer these checkpoints from the TPU VM periodically, although this is not implemented in our scripts. ## Inference This project provides custom versions of the regular recipes to run with XLA in the `xla` directory. To generate text, use the following command: ```shell python3 xla/generate/base.py --prompt "Hello, my name is" --num_samples 3 ``` For the first generation, this command will take around 17 seconds as XLA needs to compile the graph. Subsequent generations will take around 2 seconds. ## Fine-tuning To get started fine-tuning Falcon 7B with adapter, run the following command: ```shell python3 xla/scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b python3 xla/finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true ```
Multihost caveats This script is configured to save "full" checkpoints, which isn't possible on multihost TPU VMs. Here's how you can consolidate the shards into a single checkpoint after training with `state_dict_type="sharded"`: ```shell path_to_shards="out/adapter/alpaca/lit_model_adapter_finetuned" mkdir -p $path_to_shards workers=4 # 4 hosts for ((i = 0; i < workers; i++)); do # aggregate all shards locally gcloud compute tpus tpu-vm scp --worker=$i "litgpt:${path_to_shards}/*" "${path_to_shards}/" --zone us-central2-b done # copy all shards to all workers gcloud compute tpus tpu-vm scp --worker=all ${path_to_shards}/* "litgpt:${path_to_shards}/" --zone us-central2-b # consolidate the shards in each worker gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="python -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix ${path_to_shards}/checkpoint --ckpt_suffix '_rank-*-of-*.pth' --save_path ${path_to_shards}.pth" --zone us-central2-b ```
Since the TPU VM host RAM is limited (200 GB), we implement a technique to sequentially load and shard the checkpoint that can be enabled by setting `reduce_cpu_memory_usage_during_load = True`. This is necessary to load Falcon 40B.

To generate text with the adapter fine-tuned model weights, use the following command:

```shell
python3 xla/generate/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth
```

> **Warning**
> Remember to delete your instance when you are done.
>
> ```shell
> gcloud compute tpus tpu-vm delete litgpt --zone=us-central2-b
> ```

## Computational Performance

Using the [adapter finetuning script](finetune/adapter.py) and XLA's FSDP implementation, a 49.57% MFU was achieved with Falcon 7B on a v4-32 (micro batch size 7), and a 39.67% MFU was achieved with Falcon 40B on a v4-512 (micro batch size 3) at a fixed 1034 maximum sequence length.

Since the TPU VM host has limited system memory (RAM) compared to device memory (HBM), specific techniques were implemented to limit peak RAM usage when loading the model and pretrained weights before sharding, as well as when saving sharded checkpoints.

A v4 chip has 32 GiB HBM, so a host with 4 devices has 4 * 32 = 128 GiB of HBM in total. The host's 188 GiB of RAM, however, is shared across those 4 devices. Therefore, any per-device RAM allocation over 188/4 = 47 GiB would exceed the host's RAM capacity. At half precision (2 bytes per parameter), a ~24B parameter model takes about 45 GiB on CPU. It would therefore be the largest possible model under this setup without the techniques used in our scripts.
================================================ FILE: extensions/xla/__init__ ================================================ import sys from pathlib import Path # support running without installing as a package, adding extensions to the Python path wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) ================================================ FILE: extensions/xla/finetune/__init__ ================================================ ================================================ FILE: extensions/xla/finetune/adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os import sys import time from pathlib import Path from typing import Dict, List, Tuple import lightning as L import torch import torch_xla.core.xla_model as xm from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.loggers import CSVLogger from lightning.fabric.strategies import XLAFSDPStrategy from lightning.fabric.utilities import ThroughputMonitor, measure_flops from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable from litgpt.tokenizer import Tokenizer from litgpt.utils import check_valid_checkpoint_dir, chunked_cross_entropy, estimate_flops, lazy_load, num_parameters # support running without installing as a package wd = Path(__file__).parents[3].resolve() sys.path.append(str(wd)) from xla.generate.base import generate # noqa: E402 from xla.scripts.prepare_alpaca import generate_prompt # noqa: E402 from xla.utils import rank_print, sequential_load_and_fsdp_wrap # noqa: E402 eval_interval = 200 save_interval = 200 eval_iters = 100 eval_max_new_tokens = 100 log_interval = 1 devices = XLAAccelerator.auto_device_count() # the state of very large models will not fit on the system RAM, this flag can alleviate it by loading it on each rank # sequentially reduce_cpu_memory_usage_during_load = False # Hyperparameters learning_rate 
= 3e-3 batch_size = 4 micro_batch_size = batch_size gradient_accumulation_iters = batch_size // micro_batch_size assert gradient_accumulation_iters > 0 epoch_size = 50000 # train dataset size num_epochs = 5 max_iters = num_epochs * (epoch_size // micro_batch_size) // devices weight_decay = 0.02 warmup_steps = 2 * (epoch_size // micro_batch_size) // devices // gradient_accumulation_iters # 2 epochs hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} def setup( *, data_dir: Path = Path("data/alpaca"), checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"), out_dir: Path = Path("out/adapter/alpaca"), precision: str = "bf16-true", ) -> None: if devices > 1: strategy = XLAFSDPStrategy( auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, state_dict_type="full", # change to "sharded" in multi-host environments where the filesystem is not shared sequential_save=True, ) else: strategy = "auto" logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval) fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) rank_print(fabric, hparams) fabric.launch(main, data_dir, checkpoint_dir, out_dir) def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None: check_valid_checkpoint_dir(checkpoint_dir) fabric.seed_everything(1337) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) train_data = torch.load(data_dir / "train.pt") val_data = torch.load(data_dir / "test.pt") config = Config.from_name(name=checkpoint_dir.name, adapter_start_layer=0) checkpoint_path = checkpoint_dir / "lit_model.pth" rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}") if reduce_cpu_memory_usage_during_load: model = sequential_load_and_fsdp_wrap(fabric, lambda: GPT(config), checkpoint_path) else: with fabric.init_module(empty_init=False): model = 
GPT(config) checkpoint = lazy_load(checkpoint_path) # strict=False because missing keys due to adapter weights not contained in state dict model.load_state_dict(checkpoint, strict=False) model = fabric.setup_module(model) # mark as trainable only after sharding due to https://github.com/pytorch/xla/pull/5484 mark_only_adapter_as_trainable(model) # these are not correct in the sharding case rank_print(fabric, f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") rank_print(fabric, f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") trainable_params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.SGD(trainable_params, lr=learning_rate) optimizer = fabric.setup_optimizers(optimizer) fabric.seed_everything(1337 + fabric.global_rank) train_time = time.perf_counter() train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir) rank_print(fabric, f"Training time: {(time.perf_counter() - train_time):.2f}s") # Save the final checkpoint at the end of training save_path = out_dir / "lit_model_adapter_finetuned.pth" save_adapter_checkpoint(fabric, model, save_path) def train( fabric: L.Fabric, model: GPT, optimizer: torch.optim.Optimizer, train_data: List[Dict], val_data: List[Dict], checkpoint_dir: Path, out_dir: Path, ) -> None: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length = get_longest_seq_length(train_data) model.max_seq_length = longest_seq_length # to avoid recompilation, this script is configured to pad batches to the `longest_seq_length` fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) with torch.device("meta"): meta_model = GPT(model.config) mark_only_adapter_as_trainable(meta_model) # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 
# When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, # consider passing `flops_per_batch=estimated_flops` instead estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size rank_print(fabric, f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") # this assumes that all samples have a fixed length equal to the longest sequence length # which is most likely false during finetuning x = torch.randint(0, 1, (micro_batch_size, longest_seq_length)) forward_fn = lambda: meta_model(x) # noqa: F821 loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0) # noqa: F821 measured_flops = measure_flops(meta_model, forward_fn, loss_fn) rank_print(fabric, f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") del meta_model, x throughput = ThroughputMonitor(fabric, window_size=50) step_count = 0 total_t0 = time.perf_counter() xm.mark_step() for iter_num in range(1, max_iters + 1): if step_count <= warmup_steps: # linear warmup lr = learning_rate * step_count / warmup_steps for param_group in optimizer.param_groups: param_group["lr"] = lr iter_t0 = time.perf_counter() input_ids, targets = get_batch(fabric, train_data, longest_seq_length) is_accumulating = iter_num % gradient_accumulation_iters != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids, lm_head_chunk_size=128) xm.mark_step() # shift the targets such that output n predicts token n+1 logits[-1] = logits[-1][..., :-1, :] loss = chunked_cross_entropy(logits, targets[..., 1:]) fabric.backward(loss / gradient_accumulation_iters) xm.mark_step() if not is_accumulating: optimizer.step() optimizer.zero_grad() step_count += 1 else: xm.mark_step() if iter_num % log_interval == 0: t1 = time.perf_counter() throughput.update( time=t1 - total_t0, batches=iter_num, samples=iter_num * micro_batch_size, lengths=iter_num * micro_batch_size * longest_seq_length, flops=measured_flops * log_interval, ) 
throughput.compute_and_log(step=iter_num) rank_print( fabric, f"iter {iter_num} step {step_count}:" # uncomment to print the loss. this will considerably slow down the iteration times # + f" loss {loss.item():.4f}," + f" iter time: {(t1 - iter_t0) * 1000:.2f}ms" + (" (optimizer.step)" if not is_accumulating else ""), ) if not is_accumulating and step_count % eval_interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length) t1 = time.perf_counter() - t0 rank_print(fabric, f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms") fabric.barrier() if not is_accumulating and step_count % save_interval == 0: checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth" save_adapter_checkpoint(fabric, model, checkpoint_path) # xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor @torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, longest_seq_length: int ) -> torch.Tensor: rank_print(fabric, "Validating ...") model.eval() losses = torch.zeros(eval_iters) xm.mark_step() for k in range(eval_iters): input_ids, targets = get_batch(fabric, val_data, longest_seq_length) logits = model(input_ids) xm.mark_step() losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() # produce an example: instruction = "Recommend a movie for me to watch during the weekend and explain the reason." 
rank_print(fabric, instruction) sample = {"instruction": instruction, "input": ""} prompt = generate_prompt(sample) encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8) model.clear_kv_cache() output = tokenizer.decode(output) rank_print(fabric, output) model.train() return val_loss def get_batch(fabric: L.Fabric, data: List[Dict], longest_seq_length: int) -> Tuple[torch.Tensor, torch.Tensor]: ix = torch.randint(len(data), (micro_batch_size,)) input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] labels = [data[i]["labels"].type(torch.int64) for i in ix] def pad_right(x, pad_id): # pad right using a fixed longest sequence length to avoid recompilation n = longest_seq_length - len(x) return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) x, y = fabric.to_device((x, y)) return x, y def get_longest_seq_length(data: List[Dict]) -> int: # find out the minimum max_seq_length required during fine-tuning (saves memory!) return max(len(d["input_ids"]) for d in data) def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: rank_print(fabric, f"Saving adapter weights to {str(file_path)!r}") fabric.save(file_path, {"model": model}, filter={"model": adapter_filter}) if __name__ == "__main__": from jsonargparse import CLI CLI(setup) ================================================ FILE: extensions/xla/generate/__init__ ================================================ ================================================ FILE: extensions/xla/generate/adapter.py ================================================ # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. import sys import time from pathlib import Path from typing import Optional import lightning as L from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies import XLAFSDPStrategy from litgpt import Tokenizer from litgpt.adapter import GPT, Block, Config from litgpt.prompts import Alpaca from litgpt.utils import check_valid_checkpoint_dir, lazy_load # support running without installing as a package wd = Path(__file__).parents[3].resolve() sys.path.append(str(wd)) from xla.generate.base import generate # noqa: E402 from xla.utils import rank_print # noqa: E402 def setup( prompt: str = "What food do llamas eat?", *, input: str = "", sys_prompt: Optional[str] = None, adapter_path: Path = Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"), checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"), max_new_tokens: int = 100, top_k: Optional[int] = 50, temperature: float = 0.8, precision: str = "bf16-true", ) -> None: """Generates a response based on a given instruction and an optional input. This script will only work with checkpoints from the instruction-tuned Adapter model. See `xla/finetune/adapter.py`. Args: prompt: The prompt/instruction (Alpaca style). input: Optional input (Alpaca style). sys_prompt: Optional system prompt. adapter_path: Path to the checkpoint with trained adapter weights, which are the output of `xla/finetune/adapter.py`. checkpoint_dir: The path to the checkpoint folder with pretrained model weights. max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. precision: Indicates the Fabric precision setting to use. 
""" devices = XLAAccelerator.auto_device_count() strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else "auto" fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy) fabric.launch(main, prompt, input, sys_prompt, adapter_path, checkpoint_dir, max_new_tokens, top_k, temperature) def main( fabric: L.Fabric, prompt: str, input: str, sys_prompt: Optional[str], adapter_path: Path, checkpoint_dir: Path, max_new_tokens: int, top_k: Optional[int], temperature: float, ) -> None: check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml", adapter_start_layer=0) checkpoint_path = checkpoint_dir / "lit_model.pth" rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) rank_print(fabric, f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) t0 = time.perf_counter() checkpoint = lazy_load(checkpoint_path) adapter_checkpoint = lazy_load(adapter_path) checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) model.load_state_dict(checkpoint) rank_print(fabric, f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) model.eval() model = fabric.setup_module(model) tokenizer = Tokenizer(checkpoint_dir) # TODO: Load prompt style from checkpoint and apply it here prompt_style = Alpaca() prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) t0 = time.perf_counter() y = generate( model, encoded, max_returned_tokens, 
max_seq_length=max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 output = tokenizer.decode(y) output = output.split("### Response:")[1] if "### Response:" in output else output output = output.strip() fabric.print(output) tokens_generated = y.size(0) - prompt_length rank_print( fabric, f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) if __name__ == "__main__": from jsonargparse import CLI CLI(setup) ================================================ FILE: extensions/xla/generate/base.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time from pathlib import Path from typing import Optional import lightning as L import torch import torch_xla.core.xla_model as xm from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies import XLAFSDPStrategy from litgpt import GPT, Config, Tokenizer from litgpt.model import Block from litgpt.utils import check_valid_checkpoint_dir, lazy_load # support running without installing as a package wd = Path(__file__).parents[3].resolve() sys.path.append(str(wd)) from xla.utils import rank_print # noqa: E402 # xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor @torch.no_grad() def generate( model: GPT, idx: torch.Tensor, max_returned_tokens: int, *, temperature: float = 1.0, top_k: Optional[int] = None, eos_id: Optional[int] = None, ) -> torch.Tensor: """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. The implementation of this function is modified from A. Karpathy's nanoGPT. Args: model: The model to use. idx: Tensor of shape (T) with indices of the prompt sequence. max_returned_tokens: The maximum number of tokens to return (given plus generated). 
temperature: Scales the predicted logits by 1 / temperature. top_k: If specified, only sample among the tokens with the k highest probabilities. eos_id: If specified, stop generating any more token once the token is triggered. """ T = idx.size(0) assert max_returned_tokens > T if model.max_seq_length < max_returned_tokens - 1: # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do # not support it to avoid negatively impacting the overall speed raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") device, dtype = idx.device, idx.dtype # create an empty tensor of the expected final shape and fill in the current tokens empty = torch.empty(max_returned_tokens, dtype=dtype, device=device) empty[:T] = idx idx = empty # TODO: FSDP has an internal broadcasting issue, so we are forced to have this be of length 1 until it's fixed input_pos = torch.tensor([0], device=device) xm.mark_step() # generate up to a fixed number of tokens for _ in range(max_returned_tokens): x = idx.index_select(0, input_pos).view(1, -1) # forward logits = model(x, input_pos) logits = logits[0, -1] / temperature # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits = torch.where(logits < v[[-1]], -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype) # advance input_pos = input_pos[-1:] + 1 xm.mark_step() # concatenate the new generation idx = idx.index_copy(0, input_pos, idx_next) # if token is triggered, return the output (stop generation) if idx_next == eos_id: return idx[:input_pos] # include the EOS token return idx def setup( prompt: str = "What food do llamas eat?", *, num_samples: int = 1, max_new_tokens: 
int = 100, top_k: Optional[int] = 50, temperature: float = 0.8, checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"), precision: str = "bf16-true", ) -> None: """Generates text samples based on a pre-trained model and tokenizer. Args: prompt: The prompt string to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. checkpoint_dir: The checkpoint directory to load. precision: Indicates the Fabric precision setting to use. """ devices = XLAAccelerator.auto_device_count() strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else "auto" fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy) fabric.launch(main, prompt, num_samples, max_new_tokens, top_k, temperature, checkpoint_dir) def main( fabric: L.Fabric, prompt: str, num_samples: int, max_new_tokens: int, top_k: Optional[int], temperature: float, checkpoint_dir: Path, ) -> None: check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) rank_print(fabric, f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) t0 = time.perf_counter() checkpoint = lazy_load(checkpoint_path) model.load_state_dict(checkpoint.get("model", checkpoint)) rank_print(fabric, f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) model.eval() model = fabric.setup_module(model) tokenizer = Tokenizer(checkpoint_dir) encoded = tokenizer.encode(prompt, 
device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens L.seed_everything(1234) for i in range(num_samples): with fabric.init_tensor(): # enable the kv cache model.set_kv_cache(batch_size=1) t0 = time.perf_counter() y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k) t = time.perf_counter() - t0 fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length rank_print( fabric, f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr, ) if __name__ == "__main__": from jsonargparse import CLI CLI(setup) ================================================ FILE: extensions/xla/scripts/__init__ ================================================ ================================================ FILE: extensions/xla/scripts/prepare_alpaca.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
"""Implementation derived from https://github.com/tloen/alpaca-lora""" import json from pathlib import Path from typing import Optional import torch import yaml from lightning_utilities.core.imports import RequirementCache from torch.utils.data import random_split from tqdm import tqdm from litgpt.tokenizer import Tokenizer from litgpt.utils import CLI def prepare( destination_path: Path = Path("data/alpaca"), checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), val_split_fraction: float = 0.03865, # to get exactly 2000 validation samples, seed: int = 42, mask_inputs: bool = False, # as in alpaca-lora data_file_name: str = "alpaca_data_cleaned_archive.json", data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", ignore_index: int = -100, max_seq_length: Optional[int] = None, ) -> None: """Prepare the Alpaca dataset for instruction tuning. The output is a training and test dataset saved as `train.pt` and `test.pt`, which stores the preprocessed and tokenized prompts and labels. 
""" if max_seq_length is None: with open(checkpoint_dir / "model_config.yaml", encoding="utf-8") as file: config = yaml.safe_load(file) max_seq_length = config["block_size"] destination_path.mkdir(parents=True, exist_ok=True) data_file_path = destination_path / data_file_name print("Loading data file...") download_if_missing(data_file_path, data_file_url) with open(data_file_path, encoding="utf-8") as file: data = json.load(file) print("Loading tokenizer...") tokenizer = Tokenizer(checkpoint_dir) # Partition the dataset into train and test train_set, test_set = random_split( data, [1.0 - val_split_fraction, val_split_fraction], generator=torch.Generator().manual_seed(seed) ) train_set, test_set = list(train_set), list(test_set) print(f"train has {len(train_set):,} samples") print(f"test has {len(test_set):,} samples") print("Processing train split ...") train_set = [ prepare_sample( example=sample, tokenizer=tokenizer, max_length=max_seq_length, mask_inputs=mask_inputs, ignore_index=ignore_index, ) for sample in tqdm(train_set) ] torch.save(train_set, destination_path / "train.pt") print("Processing test split ...") test_set = [ prepare_sample( example=sample, tokenizer=tokenizer, max_length=max_seq_length, mask_inputs=mask_inputs, ignore_index=ignore_index, ) for sample in tqdm(test_set) ] torch.save(test_set, destination_path / "test.pt") def download_if_missing(file_path: Path, file_url: str) -> None: """Downloads the raw json data file and saves it in the given destination.""" if file_path.exists() and file_path.stat().st_size > 0: return requests_available = RequirementCache("requests") if not requests_available: raise ModuleNotFoundError(str(requests_available)) import requests with open(file_path, "w", encoding="utf-8") as f: f.write(requests.get(file_url).text) def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict: """Processes a single sample. 
Each sample in the dataset consists of: - instruction: A string describing the task - input: A string holding a special input value for the instruction. This only applies to some samples, and in others this is empty. - output: The response string This function processes this data to produce a prompt text and a label for supervised training. The prompt text is formed as a single message including both the instruction and the input. The label/target is the same message but with the response attached. Finally, both the prompt and the label get tokenized. If desired, all tokens in the label that correspond to the original input prompt get masked out (default). """ full_prompt = generate_prompt(example) full_prompt_and_response = full_prompt + example["output"] encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length) # The labels are the full prompt with response, but with the prompt masked out labels = encoded_full_prompt_and_response.clone() if mask_inputs: labels[: len(encoded_full_prompt)] = ignore_index return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels} def generate_prompt(example: dict) -> str: """Generates a standardized message to prompt the model with an instruction, optional input and a 'response' field.""" if example["input"]: return ( "Below is an instruction that describes a task, paired with an input that provides further context. " "Write a response that appropriately completes the request.\n\n" f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" ) return ( "Below is an instruction that describes a task. 
" "Write a response that appropriately completes the request.\n\n" f"### Instruction:\n{example['instruction']}\n\n### Response:" ) if __name__ == "__main__": CLI(prepare) ================================================ FILE: extensions/xla/utils.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import itertools from functools import partial from pathlib import Path from typing import Any, Callable import lightning as L import torch from lightning.fabric.strategies.xla_fsdp import XLAFSDPStrategy, _activation_checkpointing_auto_wrapper from lightning_utilities.core.rank_zero import rank_prefixed_message from litgpt import GPT def rank_print(fabric: L.Fabric, message: object, *, flush: bool = True, **kwargs: Any) -> None: if fabric.local_rank == 0: message = str(message) # let each host print, but only on rank 0 message = rank_prefixed_message(message, fabric.global_rank) # TPU VM will only print when the script finishes if `flush=False` print(message, flush=flush, **kwargs) def materialize_parameters(module: torch.nn.Module, device: torch.device) -> None: for module_name, module in module.named_modules(): if any( param.is_meta for param in itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False)) ): module.to_empty(device=device, recurse=False) module.reset_parameters() def sequential_load_and_fsdp_wrap( fabric: L.Fabric, get_model: Callable[[], GPT], checkpoint_path: Path ) -> torch.nn.Module: assert fabric._launched # similar logic could be implemented for regular FSDP, but this implementation is specific to XLAFSDP assert isinstance(fabric.strategy, XLAFSDPStrategy) with fabric.init_module(empty_init=False), torch.device("meta"): model = get_model() # TODO: this could be made faster by broadcasting in separate process groups for each host if fabric.local_rank == 0: # load the full checkpoint on a single rank to limit the system memory usage state_dict = 
torch.load(checkpoint_path, map_location="cpu", mmap=False) # mmap=True hangs else: # XLA cannot broadcast different number of tensors or different shapes in each rank. To get around this # limitation, we need to load the checkpoint on meta device to get the correct number of tensors and materialize # them as necessary state_dict = torch.load(checkpoint_path, map_location="meta", mmap=False) fsdp_kwargs = fabric.strategy._parse_fsdp_kwargs() if "auto_wrapper_callable" in fsdp_kwargs: # includes activation checkpointing if configured wrap = fsdp_kwargs.pop("auto_wrapper_callable") else: wrap = partial(_activation_checkpointing_auto_wrapper, set()) fsdp_kwargs.pop("auto_wrap_policy", None) # this needs to be removed or else root wrapping would error for i, block in enumerate(model.transformer.h): rank_print(fabric, f"Broadcasting transformer block {i}") # get the relevant piece of the state dict to_load = {} for param_name, _ in block.named_parameters(): if (key := f"transformer.h.{i}.{param_name}") not in state_dict: continue param = state_dict.pop(key) if not param.is_meta: to_load[param_name] = param else: # materialize this parameter for broadcast to work to_load[param_name] = torch.empty_like(param, device="cpu") to_load = fabric.broadcast(to_load) rank_print(fabric, f"Loading transformer block {i}") keys = block.load_state_dict(to_load, strict=False, assign=True) assert not keys.unexpected_keys # materialize any leftover meta parameters, regular FSDP does it automatically materialize_parameters(block, torch.device("cpu")) # init on CPU, FSDP will shard and move it # XLA FSDP only supports fp32 parameters. 
If the checkpoint had a different dtype, this needs to be converted # since we are loading with assign=True block = block.to(torch.float32) # shard the block rank_print(fabric, f"Wrapping transformer block {i}") wrapped_block = wrap(block, **fsdp_kwargs) model.transformer.h[i] = wrapped_block # load the rest of the state_dict, this assumes that all keys need to be loaded # an alternative technique would be to do load the rest of the state dict at once, but we want to materialize # and move the params to the xla device to reduce the system memory usage for key in list(state_dict): rank_print(fabric, f"Loading {key}") param = state_dict.pop(key) if param.is_meta: # materialize this parameter for broadcast to work param = torch.empty_like(param, device="cpu") param = fabric.broadcast(param) param = param.to(device=fabric.device, dtype=torch.float32) keys = model.load_state_dict({key: param}, strict=False, assign=True) assert not keys.unexpected_keys assert not state_dict # materialize any leftover meta parameters, regular FSDP does it automatically rank_print(fabric, "Materializing leftover parameters") materialize_parameters(model, fabric.device) return model ================================================ FILE: litgpt/__init__.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import logging import re from litgpt.api import LLM from litgpt.config import Config from litgpt.model import GPT # needs to be imported before config from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 pattern = re.compile(".*Profiler function .* will be ignored") logging.getLogger("torch._dynamo.variables.torch").addFilter(lambda record: not pattern.search(record.getMessage())) # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True __all__ = ["LLM", "GPT", "Config", "PromptStyle", "Tokenizer"] ================================================ FILE: litgpt/__main__.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import warnings import torch from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options from litgpt.chat.base import main as chat_fn from litgpt.deploy.serve import run_server as serve_fn from litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn from litgpt.finetune.adapter import setup as finetune_adapter_fn from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn from litgpt.finetune.full import setup as finetune_full_fn from litgpt.finetune.lora import setup as finetune_lora_fn from litgpt.generate.adapter import main as generate_adapter_fn from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn from litgpt.generate.base import main as generate_base_fn from litgpt.generate.full import main as generate_full_fn from litgpt.generate.sequentially import main as generate_sequentially_fn from litgpt.generate.speculative_decoding import main as generate_speculatively_fn from litgpt.generate.tp import main as generate_tp_fn from 
litgpt.parser_config import parser_commands from litgpt.pretrain import setup as pretrain_fn from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint as convert_hf_checkpoint_fn from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint as convert_lit_checkpoint_fn from litgpt.scripts.convert_pretrained_checkpoint import ( convert_pretrained_checkpoint as convert_pretrained_checkpoint_fn, ) from litgpt.scripts.download import download_from_hub as download_fn from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn PARSER_DATA = { "download": download_fn, "chat": chat_fn, "finetune": finetune_lora_fn, "finetune_lora": finetune_lora_fn, "finetune_full": finetune_full_fn, "finetune_adapter": finetune_adapter_fn, "finetune_adapter_v2": finetune_adapter_v2_fn, "pretrain": pretrain_fn, "generate": generate_base_fn, "generate_full": generate_full_fn, "generate_adapter": generate_adapter_fn, "generate_adapter_v2": generate_adapter_v2_fn, "generate_sequentially": generate_sequentially_fn, "generate_speculatively": generate_speculatively_fn, "generate_tp": generate_tp_fn, "convert_to_litgpt": convert_hf_checkpoint_fn, "convert_from_litgpt": convert_lit_checkpoint_fn, "convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn, "merge_lora": merge_lora_fn, "evaluate": evaluate_fn, "serve": serve_fn, } def _check_commands(): assert set(parser_commands()) == set(PARSER_DATA.keys()), ( "PARSER_DATA has to be kept in sync with litgpt.parser_config.parser_commands()" ) def main() -> None: _check_commands() set_docstring_parse_options(attribute_docstrings=True) set_config_read_mode(urls_enabled=True) # PyTorch bug that raises a false-positive warning # More info: https://github.com/Lightning-AI/litgpt/issues/1561 warning_message = r"The epoch parameter in `scheduler.step\(\)` was not necessary and is being deprecated.*" warnings.filterwarnings( action="ignore", message=warning_message, category=UserWarning, 
module=r".*torch\.optim\.lr_scheduler.*" ) torch.set_float32_matmul_precision("high") CLI(PARSER_DATA) if __name__ == "__main__": main() ================================================ FILE: litgpt/adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """Implementation of the paper: LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention https://arxiv.org/abs/2303.16199 Port for LitGPT """ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from typing_extensions import Self from litgpt.config import Config as BaseConfig from litgpt.model import GPT as BaseModel from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention @dataclass class Config(BaseConfig): adapter_prompt_length: int = 10 adapter_start_layer: int = 2 class GPT(BaseModel): # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. def __init__(self, config: Config) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) self.mask_cache: Optional[torch.Tensor] = None self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) def _init_weights(self, module: nn.Module) -> None: """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" super()._init_weights(module) if isinstance(module, CausalSelfAttention): module.reset_parameters() class Block(BaseBlock): def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.model.CausalSelfAttention` that adds the attention over the adaption prompt.""" def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) if block_idx >= config.adapter_start_layer: # adapter embedding layer self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) # gate for adaption self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) # kv cache for inference self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: y = super().scaled_dot_product_attention(q, k, v, mask) if self.block_idx < self.config.adapter_start_layer: return y aT = self.config.adapter_prompt_length if self.adapter_kv_cache is not None: # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av # are the same every call ak, av = self.adapter_kv_cache else: prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) aqkv = self.qkv(prefix) q_per_kv = self.config.n_head // self.config.n_query_groups aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) aqkv = aqkv.permute(0, 2, 3, 1, 4) _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) if self.config.n_query_groups != 1: # for MHA this is a no-op ak = ak.repeat_interleave(q_per_kv, dim=2) av = av.repeat_interleave(q_per_kv, dim=2) ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) av = av.view(1, -1, aT, 
self.config.head_size) # (1, nh_av, aT, hs) self.adapter_kv_cache = (ak, av) T = q.size(2) amask = torch.ones(T, aT, dtype=torch.bool, device=q.device) ay = super().scaled_dot_product_attention(q, ak, av, amask) return y + self.gating_factor * ay def reset_parameters(self) -> None: if hasattr(self, "gating_factor"): torch.nn.init.zeros_(self.gating_factor) def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with older checkpoints.""" if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: state_dict[key] = state_dict[key].permute(0, 2, 1, 3) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def mark_only_adapter_as_trainable(model: GPT) -> None: """Sets `requires_grad=False` for all non-adapter weights.""" for name, param in model.named_parameters(): param.requires_grad = adapter_filter(name, param) def adapter_filter(key: str, value: Any) -> bool: return "adapter_wte" in key or "gating_factor" in key ================================================ FILE: litgpt/adapter_v2.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
"""Implementation of the paper: LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model https://arxiv.org/abs/2304.15010 Port for LitGPT """ from dataclasses import dataclass from typing import Any, Dict, Optional, Type import torch import torch.nn as nn from typing_extensions import Self import litgpt from litgpt.adapter import GPT as BaseModel from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig from litgpt.model import Block as BaseBlock from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @dataclass class Config(BaseConfig): @property def mlp_class(self) -> Type: return getattr(litgpt.adapter_v2, self.mlp_class_name) def adapter_filter(key: str, value: Any) -> bool: adapter_substrings = ( # regular adapter v1 parameters "adapter_wte", "gating_factor", # adapter v2: new bias and scale used in Linear "adapter_scale", "adapter_bias", # adapter v2: Norm parameters are now trainable "norm_1", "norm_2", "ln_f", ) return any(s in key for s in adapter_substrings) class AdapterV2Linear(torch.nn.Module): def __init__(self, in_features: int, out_features: int, **kwargs) -> None: super().__init__() self.linear = torch.nn.Linear(in_features, out_features, **kwargs) self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.adapter_scale * (self.linear(x) + self.adapter_bias) def reset_parameters(self) -> None: nn.init.zeros_(self.adapter_bias) nn.init.ones_(self.adapter_scale) class GPT(BaseModel): # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. 
def __init__(self, config: Config) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) self.mask_cache: Optional[torch.Tensor] = None self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) def _init_weights(self, module: nn.Module) -> None: """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" super()._init_weights(module) if isinstance(module, AdapterV2Linear): module.reset_parameters() def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"} state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class Block(BaseBlock): def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) self.mlp = config.mlp_class(config) class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" # Copy&paste from :class:`model.CausalSelfAttention` def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = AdapterV2Linear(in_features=config.n_embd, 
out_features=shape, bias=config.bias or config.attn_bias) # output projection self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" mapping = { "qkv.weight": "qkv.linear.weight", "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) # For compatibility with older checkpoints if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: state_dict[key] = state_dict[key].permute(0, 2, 1, 3) for attr in ("weight", "bias"): legacy_key = f"{prefix}attn.linear.{attr}" current_key = f"{prefix}qkv.linear.{attr}" if legacy_key in state_dict: state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class GptNeoxMLP(litgpt.model.GptNeoxMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { "fc.weight": "fc.linear.weight", "fc.bias": "fc.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class LLaMAMLP(litgpt.model.LLaMAMLP): def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None: nn.Module.__init__(self) self.intermediate_size = intermediate_size or config.intermediate_size 
self.fc_1 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias) self.fc_2 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias) self.proj = AdapterV2Linear(self.intermediate_size, config.n_embd, bias=config.bias) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { "fc_1.weight": "fc_1.linear.weight", "fc_1.bias": "fc_1.linear.bias", "fc_2.weight": "fc_2.linear.weight", "fc_2.bias": "fc_2.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) class LLaMAMoE(litgpt.model.LLaMAMoE): def __init__(self, config: Config) -> None: nn.Module.__init__(self) self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False) self.experts = nn.ModuleList( LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert) ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = {"gate.weight": "gate.linear.weight"} state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def mark_only_adapter_v2_as_trainable(model: GPT) -> None: """Sets requires_grad=False for all non-adapter weights""" for name, param in model.named_parameters(): param.requires_grad = adapter_filter(name, param) ================================================ FILE: litgpt/api.py 
================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. # # This file implements the LitGPT Python API import sys import time from pathlib import Path from typing import Any, Callable, List, Literal, Optional, Tuple, Union import lightning as L import numpy as np import torch from lightning.fabric.accelerators import CUDAAccelerator from lightning.fabric.plugins import BitsandbytesPrecision from tqdm import tqdm from litgpt.chat.base import generate as stream_generate_fn from litgpt.config import Config, name_to_config from litgpt.generate.base import generate as generate_fn from litgpt.generate.sequentially import sequential from litgpt.generate.tp import tensor_parallel from litgpt.model import GPT from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style, save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( auto_download_checkpoint, check_file_size_on_cpu_and_warn, check_nvlink_connectivity, chunked_cross_entropy, copy_config_files, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, save_config, ) class LLM(torch.nn.Module): def __init__( self, model: GPT, preprocessor=None, prompt_style: PromptStyle = None, devices: Union[int, List[int]] = None, config: Config = None, checkpoint_dir: Path = None, fabric: L.Fabric = None, generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None, kv_cache_initialized: bool = False, fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None, ) -> None: super().__init__() self.model = model self.preprocessor = preprocessor self.devices = devices self.prompt_style = prompt_style self.config = config self.checkpoint_dir = checkpoint_dir self.fabric = fabric self.generate_strategy = generate_strategy self.kv_cache_initialized = kv_cache_initialized self.fixed_kv_cache_size = fixed_kv_cache_size self.prev_generated_seq_length = 0 """ LLM model class for 
inference, pretraining, and finetuning. Example: from litgpt.api import LLM llm = LLM.load("microsoft/phi-2") text = llm.generate("What do Llamas eat?", top_k=1) print(text) """ @property def tokenizer(self): return self.preprocessor.tokenizer def state_dict(self, destination=None, prefix="", keep_vars=False): return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) def load_state_dict(self, state_dict, strict=True): return self.model.load_state_dict(state_dict, strict=strict) def forward( self, input_ids: torch.Tensor, target_ids: Optional[torch.Tensor] = None, loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: logits = self.model(input_ids) if target_ids is not None: if loss_fn is None: loss_fn = chunked_cross_entropy loss = loss_fn(logits[..., :-1, :], target_ids[..., 1:]) return logits, loss else: return logits def trainer_setup(self, trainer_ckpt: Optional[Path] = None) -> None: """Initializes the model checkpoint for PyTorch Lightning Trainer contexts""" self.model = GPT(self.config) if trainer_ckpt is not None: # strip the object name key from the state_dict state_dict = torch.load(trainer_ckpt, weights_only=True)["state_dict"] first_key = next(iter(state_dict)) prefix = first_key.split(".")[0] + "." keys_to_modify = [key for key in state_dict if key.startswith(prefix)] for key in keys_to_modify: new_key = key.replace(prefix, "", 1) state_dict[new_key] = state_dict.pop(key) self.load_state_dict(state_dict, strict=True) elif self.checkpoint_dir is not None: state_dict = torch.load(self.checkpoint_dir / "lit_model.pth", weights_only=False) self.load_state_dict(state_dict, strict=False) else: raise ValueError( "No checkpoint found. Either provide a valid path via `trainer_ckpt` " "or ensure that `self.checkpoint_dir` points to a folder containing a `lit_model.pth` weight file." 
) def save(self, out_dir: Optional[Path] = None, prompt_style: Optional[PromptStyle] = None) -> None: out_dir = Path(out_dir) save_path = out_dir / "lit_model.pth" save_path.parent.mkdir(parents=True, exist_ok=True) if prompt_style is None: prompt_style = PromptStyle.from_config(self.config) if self.fabric is None: torch.save(self.state_dict(), save_path) else: self.fabric.save(save_path, self.state_dict()) if self.fabric is None or self.fabric.global_rank == 0: # If initialization a model with random weights, the checkpoint dir can be none if self.checkpoint_dir is not None: copy_config_files(Path(self.checkpoint_dir), save_path.parent) else: save_config(self.config, out_dir) save_prompt_style(prompt_style, save_path.parent) @classmethod def load( cls, model: str, init: Optional[Literal["pretrained", "random"]] = "pretrained", tokenizer_dir: Optional[Path] = None, access_token: Optional[str] = None, distribute: Optional[Literal["auto"]] = "auto", ) -> "LLM": """ Loads the LLM from a local directory or model hub. Arguments model: A local path to a directory containing the model weights or a valid model name. You can get a list of valid model names via the `litgpt download list` command line argument. init: If "pretrained" (default), downloads the model from the HF Hub if a local model can't be found at the `model` directory name; otherwise loads the model from the local directory. If "random", initializes the `model` with random weights. tokenizer_dir: An optional tokenizer directory if `model` is not a checkpoint directory, or if a user wants to use a different tokenizer instead. access_token: Optional API token to access models with restrictions when using `init="pretrained"`. distribute: If "auto" (default), initializes the model on a single GPU if available and otherwise on the CPU. To have more control over the model distribution strategy and utilize multiple GPUs, you can set `llm = LLM.load(..., distribute=None)` and call `llm.distribute(...)` manually. 
""" allowed_init = {"pretrained", "random"} if init == "pretrained": checkpoint_dir = auto_download_checkpoint( model_name=model, access_token=access_token, ignore_tokenizer_files=tokenizer_dir is not None ) config = Config.from_file(checkpoint_dir / "model_config.yaml") elif init == "random": checkpoint_dir = None try: config = Config.from_name(model) except ValueError: print(f"Model name {model} is not supported.\n") available_models = "\n".join(sorted(name_to_config)) print(f"Available values:\n{available_models}") return else: raise ValueError(f"Invalid init option: {init}. Must be one of {allowed_init}") torch.set_float32_matmul_precision("high") if tokenizer_dir is not None: tokenizer_dir = extend_checkpoint_dir(Path(tokenizer_dir)) tokenizer = Tokenizer(tokenizer_dir) elif checkpoint_dir is not None: tokenizer = Tokenizer(checkpoint_dir) else: raise ValueError("Provide a path to a tokenizer directory via the `tokenizer_dir` setting.") if checkpoint_dir is not None: prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) else: prompt_style = PromptStyle.from_config(config) if distribute == "auto": if torch.cuda.is_available(): accelerator = "cuda" elif torch.backends.mps.is_available(): accelerator = "mps" else: accelerator = "cpu" fabric = L.Fabric( accelerator=accelerator, devices=1, precision=get_default_supported_precision(training=False), ) with fabric.init_module(empty_init=False): model = GPT(config) model.eval() preprocessor = Preprocessor(tokenizer, device=fabric.device) if checkpoint_dir is not None: checkpoint_path = checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) load_checkpoint(fabric, model, checkpoint_path) model = fabric.setup_module(model) else: preprocessor = Preprocessor(tokenizer, device="cuda" if torch.cuda.is_available() else "cpu") model = None fabric = None return cls( model=model, preprocessor=preprocessor, 
prompt_style=prompt_style, config=config, checkpoint_dir=checkpoint_dir, fabric=fabric, generate_strategy=None, kv_cache_initialized=False, fixed_kv_cache_size=False, ) def distribute( self, accelerator: Literal["cpu", "cuda", "auto"] = "auto", devices: Union[int, Literal["auto"]] = "auto", precision: Optional[Any] = None, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None, fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None, ) -> None: """ Moves the model onto specified devices for single-GPU or multi-GPU inference accelerator: Which device type to load the model on ("cpu", "gpu", "mps", "cuda", or "auto") devices: The number of devices (1, 2, etc.) or "auto", which uses all available devices quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. For instance, "32-true", "16-mixed", "16-true", "bf16-mixed", "bf16-true". For more details, see https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision generate_strategy: Whether to use a sequential model generation strategy. The "sequential" settings allows running models that wouldn't fit in a single card by partitioning the transformer blocks across all devices and running them sequentially. Sequential generation may be slower but allows using larger models. Note that sequential generation sets `fixed_kv_cache_size="max_model_supported"`. You can set it to a lower integer value, `fixed_kv_cache_size=256` to reduce memory. The `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned via `llm.generate(...)`. 
fixed_kv_cache_size: If set to an integer value or "max_model_supported" is set, the kv-cache won't be resized dynamically during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy="sequential`. Note that the chosen `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned in `llm.generate(...)`. """ if self.checkpoint_dir is None: raise NotImplementedError( "The LLM was initialized with init='random' but .distribute() " "currently only supports pretrained weights." ) allowed_accelerators = {"cpu", "gpu", "cuda", "mps", "auto"} if accelerator not in allowed_accelerators: raise ValueError(f"Invalid accelerator: {accelerator}. Must be one of {allowed_accelerators}.") if accelerator == "auto": if torch.cuda.is_available(): accelerator = "cuda" elif torch.backends.mps.is_available(): accelerator = "mps" else: accelerator = "cpu" if generate_strategy in ("sequential", "tensor_parallel") and accelerator not in ("cuda", "gpu"): raise NotImplementedError( f"generate_strategy='{generate_strategy}' is only supported for accelerator='cuda'|'gpu'." ) if devices == "auto": if generate_strategy in ("sequential", "tensor_parallel"): total_devices = CUDAAccelerator.auto_device_count() else: total_devices = 1 elif isinstance(devices, int) and accelerator == "cuda": use_devices = calculate_number_of_devices(devices) total_devices = CUDAAccelerator.auto_device_count() if use_devices > total_devices: raise ValueError( f"You selected more devices ({use_devices}) than available in your system ({total_devices})." ) else: total_devices = use_devices if total_devices > 1 and generate_strategy not in ("sequential", "tensor_parallel"): raise NotImplementedError( "Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'." 
) elif accelerator == "cpu" or accelerator == "mps": total_devices = 1 else: raise ValueError(f"devices argument must be an integer or 'auto', got {devices}") print(f"Using {total_devices} device(s)", file=sys.stderr) if precision is None: precision = get_default_supported_precision(training=False) print("Precision set", file=sys.stderr) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("The combination of quantization and mixed precision is not supported.") dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None # set "ddp" as the strategy for the launching functionality, but there's no data-parallelism if generate_strategy != "tensor_parallel": fabric = L.Fabric( accelerator=accelerator, devices=1, # Otherwise sequential wouldn't work, see litgpt/generate/sequentially.py # devices=devices, precision=precision, plugins=plugins, ) else: fabric = L.Fabric( accelerator=accelerator, devices=total_devices, strategy="ddp", precision=precision, plugins=plugins ) if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1: check_nvlink_connectivity(fabric) fabric.launch() print("Fabric launched", file=sys.stderr) self.kv_cache_initialized = False if generate_strategy is None: with fabric.init_module(empty_init=(total_devices > 1)): model = GPT(self.config) model.eval() if self.checkpoint_dir is not None: load_checkpoint(fabric, model, self.checkpoint_dir / "lit_model.pth") model = fabric.setup_module(model) if fixed_kv_cache_size is not None: if fixed_kv_cache_size is None or fixed_kv_cache_size == "max_model_supported": kv_cache_size = model.max_seq_length else: kv_cache_size = fixed_kv_cache_size model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device) self.kv_cache_initialized = True self.fixed_kv_cache_size = fixed_kv_cache_size elif 
generate_strategy in ("sequential", "tensor_parallel"): with fabric.init_tensor(), torch.device("meta"): model = GPT(self.config) model.eval() if generate_strategy == "sequential": state_dict = torch.load( str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False ) model.load_state_dict(state_dict, assign=True) model = fabric.setup_module(model, move_to_device=False) if fixed_kv_cache_size is None: fixed_kv_cache_size = "max_model_supported" if fixed_kv_cache_size == "max_model_supported": kv_cache_size = model.max_seq_length else: kv_cache_size = fixed_kv_cache_size model = sequential(model, fabric.device, kv_cache_size, total_devices) self.fixed_kv_cache_size = fixed_kv_cache_size elif generate_strategy == "tensor_parallel": if fabric.global_rank == 0: pbar = tqdm(total=fabric.world_size, desc="Loading model weights") for rank in range(fabric.world_size): if fabric.global_rank == rank: state_dict = torch.load( str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False, ) model.load_state_dict(state_dict, assign=True) # cannot use `.setup_module` because it will wrap with DDP model = fabric._precision.convert_module(model) model = tensor_parallel(fabric, model) with fabric.init_tensor(): if fixed_kv_cache_size is None: fixed_kv_cache_size = "max_model_supported" if fixed_kv_cache_size == "max_model_supported": kv_cache_size = model.max_seq_length else: kv_cache_size = fixed_kv_cache_size model.max_seq_length = kv_cache_size # the rope cache which is on meta device model.cos, model.sin = model.rope_cache() # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() model = fabric.to_device(model) fabric.barrier() if fabric.global_rank == 0: pbar.update(1) if fabric.global_rank == 0: pbar.close() self.kv_cache_initialized = True else: raise ValueError(f"Unsupported generate_strategy: {generate_strategy}") self.model = model self.fabric = fabric self.preprocessor.device = fabric.device 
@torch.inference_mode() def generate( self, prompt: str, sys_prompt: Optional[str] = None, max_new_tokens: int = 50, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, return_as_token_ids: bool = False, stream: bool = False, ) -> Union[str, torch.Tensor]: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. Arguments: model: The model to use. prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt string to use for generating the samples. The system prompt allows the user to provide additional instructions to shape all responses by providing additional context, behavioral guidelines, style, and constraints. max_new_tokens: The maximum number of new tokens to return. temperature: Scales the predicted logits by 1 / temperature. top_k: If specified, only sample among the tokens with the k highest probabilities. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p return_as_token_ids: If True, returns the token IDs as a torch.Tensor. Otherwise, returns the decoded text as a string. stream: If True, returns a generator that yields tokens as they are generated. At the moment, this setting is slower and may use more memory than the non-streaming version. We plan to resolve this in the future. 
""" if self.model is None: raise AttributeError( "The model is not initialized yet; use the .distribute() " "or .trainer_setup() method to initialize the model." ) input_ids = self._text_to_token_ids(prompt, sys_prompt) prompt_length = input_ids.size(0) max_returned_tokens = prompt_length + max_new_tokens if not self.kv_cache_initialized: if self.fabric is not None: device = self.fabric.device else: device = self.preprocessor.device self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device) self.kv_cache_initialized = True # Dynamically grow the kv cache size if necessary if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens: tmp_device = self.model.mask_cache.device self.model.clear_kv_cache() self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device) else: for block in self.model.transformer.h: block.attn.kv_cache.reset_parameters() self.prev_generated_seq_length = max_returned_tokens self.model.eval() def iterator(): outputs = stream_generate_fn( model=self.model, prompt=input_ids, max_returned_tokens=max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=([self.preprocessor.tokenizer.eos_id],), ) if return_as_token_ids: yield from outputs else: for output in outputs: yield self.preprocessor.decode(output) return if stream: outputs = iterator() else: outputs = generate_fn( model=self.model, prompt=input_ids, max_returned_tokens=max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=self.preprocessor.tokenizer.eos_id, include_prompt=False, ) if stream: return outputs elif return_as_token_ids: return outputs else: return self.preprocessor.decode(outputs) def _text_to_token_ids(self, prompt: str, sys_prompt: Optional[str] = None) -> torch.Tensor: """Utility method to convert a prompt text to token IDs""" prompt = self.prompt_style.apply(prompt, sys_prompt=sys_prompt) input_ids = self.preprocessor.encode(prompt) 
return input_ids def benchmark(self, num_iterations=1, **kwargs): """ A wrapper around the .generate() method to calculate runtime performance. Arguments: num_iterations: How often the `.generate()` call is repeated. kwargs: Keyword arguments that are passed to the .generate() method. """ benchmark_dict = {} for i in range(num_iterations): time_to_first_token = None t0 = time.perf_counter() outputs = self.generate(**kwargs) if kwargs.get("stream", False): gen_outputs = [] for e in outputs: if time_to_first_token is None: t1 = time.perf_counter() time_to_first_token = t1 - t0 gen_outputs.append(e) outputs = "".join(gen_outputs) else: outputs = self.generate( **kwargs, ) benchmark_dict.setdefault("Seconds total", []).append(time.perf_counter() - t0) benchmark_dict.setdefault("Seconds to first token", []).append(time_to_first_token) tokens_generated = self.preprocessor.encode(outputs).size(0) benchmark_dict.setdefault("Tokens generated", []).append(tokens_generated) benchmark_dict.setdefault("Inference speed in tokens/sec", []).append( benchmark_dict["Tokens generated"][-1] / benchmark_dict["Seconds total"][-1] ) if self.fabric is not None and self.fabric.device.type == "cuda": benchmark_dict.setdefault("Total GPU memory allocated in GB", []).append( torch.cuda.max_memory_allocated() / 1e9 ) return outputs, benchmark_dict class Preprocessor: """ Preprocessor class for tokenization and de-tokenization. """ def __init__(self, tokenizer: Tokenizer, device: str = "cpu") -> None: self.tokenizer = tokenizer self.device = device def encode(self, text: str) -> torch.Tensor: return self.tokenizer.encode(text, device=self.device) def decode(self, token_ids: torch.Tensor) -> str: return self.tokenizer.decode(token_ids) def calculate_number_of_devices(devices): """ Utility function to calculate the number of devices. 
""" num_devices = devices if isinstance(devices, int) else len(devices) if isinstance(devices, list) else 0 return num_devices def benchmark_dict_to_markdown_table(data): """ Converts .benchmark() outputs to a markdown table """ markdown_table = ( "| Metric | Mean | Std Dev |\n" ) markdown_table += ( "|-------------------------------------|-----------------------------|-----------------------------|\n" ) for key, values in data.items(): mean_value = np.mean(values) std_dev_value = np.std(values, ddof=1) formatted_mean = f"{mean_value:.2f}" formatted_std_dev = f"{std_dev_value:.2f}" markdown_table += f"| {key.ljust(35)} | {formatted_mean.ljust(27)} | {formatted_std_dev.ljust(27)} |\n" return markdown_table def pull_request_benchmark_util(model_name="microsoft/phi-2", num_iterations=6): def print_table(header, data): print(f"\n### {header}\n") markdown_table = ( f"| Metric | First Iteration | " f"Iter 2-{num_iterations} Mean | Iter 2-{num_iterations} Standard Dev. |\n" f"|--------------------------------------|-----------------|" f"-------------------|-------------------------|\n" ) for key, value in data.items(): first_iteration = f"{value[0]:.2f}" if value[0] is not None else "N/A" clean_values = [v for v in value[1:] if v is not None] if clean_values: mean_value = np.mean(clean_values) std_dev_value = np.std(clean_values, ddof=1) mean_str = f"{mean_value:.2f}" std_dev_str = f"{std_dev_value:.2f}" else: mean_str = "N/A" std_dev_str = "N/A" markdown_table += f"| {key:<36} | {first_iteration:<15} | {mean_str:<17} | {std_dev_str:<23} |\n" print(markdown_table) import subprocess try: g_hash = subprocess.run( ["git", "rev-parse", "--short", "HEAD"], capture_output=True, text=True, check=True ).stdout.strip() print(f"Git Commit Hash: {g_hash}") except subprocess.CalledProcessError: print("Git Commit Hash: N/A") print(f"PyTorch version: {torch.__version__}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}\n") # 1st table llm 
= LLM.load( model=model_name, ) text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1) print_table(f"Defaults ({model_name}), 1st time", bench_d) del llm # 2nd table llm = LLM.load( model=model_name, ) text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1) print_table(f"Defaults ({model_name}), 2nd time", bench_d) del llm # 3rd table llm = LLM.load( model=model_name, ) text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1, stream=True) print_table("stream=True", bench_d) del llm # 4th table llm = LLM.load(model=model_name, distribute=None) llm.distribute(fixed_kv_cache_size=500) text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1, stream=True) print_table("stream=True + fixed_kv_cache=500", bench_d) ================================================ FILE: litgpt/args.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import math import warnings from dataclasses import dataclass from typing import Dict, Optional, Union @dataclass class TrainArgs: """Training-related arguments""" save_interval: Optional[int] = 1000 """Number of optimizer steps between saving checkpoints""" log_interval: int = 1 """Number of iterations between logging calls""" global_batch_size: int = 64 """Number of samples between optimizer steps across data-parallel ranks""" micro_batch_size: int = 4 """Number of samples per data-parallel rank""" lr_warmup_steps: Optional[int] = 100 """Number of iterations with learning rate warmup active""" lr_warmup_fraction: Optional[float] = None """The fraction of an epoch to use for learning rate warmup""" epochs: Optional[int] = None """Number of epochs to train on""" # TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs? 
max_tokens: Optional[int] = None """Total number of tokens to train on""" max_steps: Optional[int] = None """Limits the number of optimizer steps to run""" max_time: Optional[float] = None """Limits the number of seconds to train for""" max_seq_length: Optional[int] = None """Limits the length of samples""" tie_embeddings: Optional[bool] = None """Whether to tie the embedding weights with the language modeling head weights""" # Optimization args max_norm: Optional[float] = None min_lr: float = 6e-5 def __post_init__(self) -> None: if self.lr_warmup_fraction and self.lr_warmup_steps: raise ValueError( "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one." ) if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1): raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.") if self.lr_warmup_steps and self.max_steps and (self.lr_warmup_steps >= self.max_steps): warnings.warn( "`--train.lr_warmup_steps` should be less than `--train.max_steps`." 
f" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.", UserWarning, ) def gradient_accumulation_iters(self, devices: int, num_nodes: int = 1) -> int: """Number of iterations between gradient synchronizations""" gradient_accumulation_iters = self.batch_size(devices, num_nodes) // self.micro_batch_size assert gradient_accumulation_iters > 0 return gradient_accumulation_iters def batch_size(self, devices: int, num_nodes: int = 1) -> int: """Number of samples between optimizer steps per data-parallel rank""" batch_size = self.global_batch_size // (devices * num_nodes) assert batch_size > 0 return batch_size def warmup_iters(self, devices: int, num_nodes: int, max_iters: int, train_dataloader) -> int: """Number of iterations to warm up the learning rate.""" if self.lr_warmup_fraction: return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader))) if self.lr_warmup_steps: return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices, num_nodes)) return 0 @dataclass class EvalArgs: """Evaluation-related arguments""" interval: int = 600 """Number of optimizer steps between evaluation calls""" max_new_tokens: Optional[int] = None """Number of tokens to generate""" max_iters: int = 100 """Number of iterations""" initial_validation: bool = False """Whether to evaluate on the validation set at the beginning of the training""" final_validation: bool = True """Whether to evaluate on the validation set at the end of the training""" evaluate_example: Union[str, int] = "first" """How to pick an example instruction to evaluate periodically during training. Can be "first", "random", or an integer index to pick a specific example.""" @dataclass class LogArgs: """Logging-related arguments. 
Different loggers use different fields.""" # === WandB Fields === project: Optional[str] = None """WandB project name""" run: Optional[str] = None """WandB run name (defaults to generated name)""" group: Optional[str] = None """WandB group name""" # === LitLogger Fields (Lightning.ai) === teamspace: Optional[str] = None """Teamspace name where charts and artifacts will appear""" metadata: Optional[Dict] = None """Extra metadata to associate with the experiment as tags""" log_model: bool = False """If True, automatically log model checkpoints as artifacts""" save_logs: bool = True """If True, capture and upload terminal logs""" checkpoint_name: Optional[str] = None """Override the base name for logged checkpoints""" ================================================ FILE: litgpt/chat/__init__.py ================================================ ================================================ FILE: litgpt/chat/base.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time from pathlib import Path from pprint import pprint from typing import Iterator, List, Literal, Optional, Tuple import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from litgpt.config import Config from litgpt.model import GPT from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer from litgpt.utils import ( auto_download_checkpoint, check_file_size_on_cpu_and_warn, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, ) @torch.inference_mode() def generate( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, *, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, stop_tokens: Tuple[List[int], ...] 
= (), ) -> Iterator[torch.Tensor]: """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as possible. Arguments: model: The model to use. prompt: Tensor of shape (T) with indices of the prompt sequence. max_returned_tokens: The maximum number of tokens to return (given plus generated). temperature: Scales the predicted logits by 1 / temperature top_k: If specified, only sample among the tokens with the k highest probabilities. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p stop_tokens: If specified, stop generating any more token once one of this list is generated. 
""" from litgpt.generate.base import generate_fn return generate_fn( include_prompt=False, include_eos=False, model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, ) def process_prompt( prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens ): prompt = prompt_style.apply(prompt=prompt) encoded_prompt = tokenizer.encode(prompt, device=fabric.device) if max_new_tokens is None: max_returned_tokens = model.max_seq_length else: first_turn = model.mask_cache is None max_returned_tokens = encoded_prompt.size(0) + max_new_tokens if first_turn or max_returned_tokens > model.max_seq_length: model.max_seq_length = max_returned_tokens model.set_kv_cache(batch_size=1, device=fabric.device) y: Iterator[torch.Tensor] = generate( model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, ) token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device) fabric.print(">> Reply: ", end="") t0 = time.perf_counter() tokens_generated = 0 for tok in token_generator: tokens_generated += 1 fabric.print(tok, end="", flush=True) t = time.perf_counter() - t0 for block in model.transformer.h: block.attn.kv_cache.reset_parameters() fabric.print( f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens", file=sys.stderr, ) fabric.print() def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens): while True: try: if not multiline: prompt = input(">> Prompt: ") else: print(">> Prompt: (Type '!submit' on a new line to end input).") prompt_lines = [] while True: line = input() if line.strip().lower() in ("!submit", "!quit", "!exit"): break prompt_lines.append(line) prompt = "\n".join(prompt_lines) except KeyboardInterrupt: break prompt = prompt.strip() if not prompt or 
prompt.lower() in ("!quit", "!exit"): break process_prompt( prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens ) @torch.inference_mode() def main( checkpoint_dir: Path, *, max_new_tokens: int = 50, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, precision: Optional[str] = None, compile: bool = False, multiline: bool = False, access_token: Optional[str] = None, ) -> None: """Chat with a model. Args: checkpoint_dir: A local path to a directory containing the model weights or a valid model name. You can get a list of valid model names via the `litgpt download list` command line argument. max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. 
quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. compile: Whether to use compilation to speed up token generation. Will increase startup time. multiline: Whether to support multiline input prompts. access_token: Optional API token to access models with restrictions. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) # Merge if this is a raw LoRA checkpoint checkpoint_path = checkpoint_dir / "lit_model.pth" if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file(): print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.") merge_lora(checkpoint_dir) if not checkpoint_path.is_file(): checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) checkpoint_path = checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) config = Config.from_file(checkpoint_dir / "model_config.yaml") with fabric.init_module(empty_init=True): model = GPT(config) if compile: print( "IMPORTANT: with enabled compilation the KV-cache size is determined by model's maximum context size, which leads to " "a higher memory consumption. 
In case of an OOM error, try to set `--compile=False`." ) model.set_kv_cache(batch_size=1) load_checkpoint(fabric, model, checkpoint_path) model.eval() if compile: torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True global next_token next_token = torch.compile(next_token, mode="reduce-overhead", dynamic=True) model = fabric.setup_module(model) tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) stop_tokens = prompt_style.stop_tokens(tokenizer) if multiline: exit_instruction = "To exit, enter '!quit' or '!exit' on an empty prompt and press 'Enter'." else: exit_instruction = "To exit, press 'Enter' on an empty prompt." print(f"Now chatting with {config.name}.\n{exit_instruction}\n") L.seed_everything(1234) interact( multiline=multiline, model=model, tokenizer=tokenizer, prompt_style=prompt_style, fabric=fabric, temperature=temperature, max_new_tokens=(None if compile else max_new_tokens), top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, ) if fabric.device.type == "cuda": fabric.print(f"\nMemory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/config.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path from typing import Any, List, Literal, Optional, Type, Union import yaml from typing_extensions import Self def find_multiple(n: int, k: int) -> int: """Utility function for finding the nearest value to n which is a multiple of k. 
NOTE: We define this function in this module rather than `litgpt.utils` so that users can import this file to do configuration manipulations in Python environments which do not include all the dependencies demanded by `litgpt.utils`. """ assert k > 0 if n % k == 0: return n return n + k - (n % k) @dataclass class Config: name: str = "" hf_config: dict = field(default_factory=dict) # General size parameters block_size: int = 4096 n_layer: int = 16 n_embd: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 padded_vocab_size: Optional[int] = None # Transformer block (structure, normalizations) norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" norm_eps: float = 1e-5 norm_qk: bool = False norm_qk_type: Literal["default", "olmo2"] = "default" post_attention_norm: bool = False post_mlp_norm: bool = False parallel_residual: bool = True shared_attention_norm: bool = False # Transformer block (self-attention) n_head: int = 32 head_size: Optional[int] = None # to use multi-head attention (MHA), set this to `n_head` (default) # to use multi-query attention (MQA), set this to 1 # to use grouped-query attention (GQA), set this to a value in between # Example with `n_head=4` # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ # │ │ │ │ │ │ │ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ # MHA GQA MQA # n_query_groups=4 n_query_groups=2 n_query_groups=1 # # credit https://arxiv.org/pdf/2305.13245.pdf n_query_groups: Optional[int] = None attn_bias: bool = False attention_scores_scalar: Optional[int] = None # If 
`sliding_window_size` is given, sliding window attention with this # size is used in layers where `sliding_window_indices` has a 1. The # default is all 1, so that sliding window attention is used in all # layers. If `len(sliding_window_indices) > n_layer`, we only use the # initial part. sliding_window_size: Optional[int] = None sliding_window_indices: Optional[List[int]] = None # if `attention_logit_softcapping` is used, cannot use optimized # `torch.nn.functional.scaled_dot_product_attention` (which implements # Flash attention), may result in higher memory and runtime footprint. attention_logit_softcapping: Optional[float] = None # Rotary position embedding (RoPE) rope_base: int = 10000 rotary_percentage: float = 0.25 rope_condense_ratio: int = 1 rope_adjustments: Optional[dict] = None rope_interleave: bool = False # Transformer block (MLP) intermediate_size: Optional[int] = None moe_intermediate_size: Optional[int] = None bias: bool = True mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" gelu_approximate: str = "none" n_expert: int = 0 n_shared_expert: Optional[int] = None n_expert_groups: Optional[int] = None n_topk_groups: Optional[int] = None n_topk_scores_per_group: Optional[int] = None n_expert_per_token: int = 0 first_k_dense_replace: Optional[int] = None routed_scaling_factor: float = 1.0 norm_topk_prob: bool = False # GPT before/after blocks scale_embeddings: bool = False lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None norm_1: bool = True norm_2: bool = True latent_attention: Optional[dict] = None # The base period of the RoPE embeddings for local attention. # If not provided, `rope_base` will be used for both local and global attention. rope_local_base_freq: Optional[float] = None # If provided, must have `>= n_layer` entries, either 0 or 1. For 0, # `rope_base` is used, for 1 `rope_local_base_freq` is used. If # `len(rope_indices) > n_layer`, we only use the initial part. 
rope_indices: Optional[List[int]] = None def __post_init__(self): if not self.name: self.name = self.hf_config.get("name", self.name) if self.head_size is None: assert self.n_embd % self.n_head == 0 self.head_size = self.n_embd // self.n_head # vocab size should be a power of 2 to be optimal on hardware. compute the closest value if self.padded_vocab_size is None: self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple) else: # vocab size shouldn't be larger than padded vocab size self.vocab_size = min(self.vocab_size, self.padded_vocab_size) # compute the number of query groups if self.n_query_groups is not None: assert self.n_head % self.n_query_groups == 0 else: self.n_query_groups = self.n_head # compute the intermediate size for MLP if not set if self.intermediate_size is None: if self.mlp_class_name == "LLaMAMLP": raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`") self.intermediate_size = 4 * self.n_embd self.rope_n_elem = int(self.rotary_percentage * self.head_size) if self.sliding_window_size is not None: self.sliding_window_indices = check_indicator_and_length( self.sliding_window_indices, name="sliding_window_indices", required_length=self.n_layer, ) if self.rope_local_base_freq is not None: self.rope_indices = check_indicator_and_length( self.rope_indices, name="rope_indices", required_length=self.n_layer, ) if self.latent_attention is not None: self.q_lora_rank = self.latent_attention.get("q_lora_rank") self.kv_lora_rank = self.latent_attention.get("kv_lora_rank") self.qk_rope_head_dim = self.latent_attention.get("qk_rope_head_dim") self.qk_nope_head_dim = self.latent_attention.get("qk_nope_head_dim") self.v_head_dim = self.latent_attention.get("v_head_dim") assert ( self.q_lora_rank and self.kv_lora_rank and self.qk_rope_head_dim and self.qk_nope_head_dim and self.v_head_dim ) is not None assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA" self.qk_head_dim = 
self.qk_rope_head_dim + self.qk_nope_head_dim self.rope_n_elem = self.qk_rope_head_dim if self.first_k_dense_replace is not None: assert self.mlp_class_name == "LLaMAMoE" if self.n_expert_groups is not None: assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1 assert self.n_topk_groups is not None experts_per_group = self.n_expert // self.n_expert_groups assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group @classmethod def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]: if name not in name_to_config: # search through all `config['hf_config']['name']` try: conf_dict = next( config for config in configs if name == config["hf_config"]["name"] or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name ) except StopIteration: raise ValueError(f"{name!r} is not a supported config name") else: conf_dict = name_to_config[name] conf_dict = conf_dict.copy() conf_dict.update(kwargs) return cls(**conf_dict) @classmethod def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self: with open(path, encoding="utf-8") as fp: file_kwargs = yaml.safe_load(fp) if file_kwargs is None: raise ValueError(f"{path} is empty which is likely unexpected.") file_kwargs.update(kwargs) return cls(**file_kwargs) @classmethod def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self: """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`.""" if (config_path := path / "model_config.yaml").is_file(): return cls.from_file(config_path, **kwargs) if (model_name := path.name) in name_to_config: return cls.from_name(model_name, **kwargs) raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.") @property def mlp_class(self) -> Type: # `self.mlp_class_name` cannot be the type to keep the config serializable import litgpt.model return getattr(litgpt.model, self.mlp_class_name) @property def 
norm_class(self) -> Type: # `self.norm_class_name` cannot be the type to keep the config serializable from functools import partial import torch # Torch import is lazy to make config loading faster if self.norm_class_name == "RMSNorm": from litgpt.model import RMSNorm return partial(RMSNorm, add_unit_offset="Gemma" in self.name) if self.norm_class_name == "LayerNorm" and "OLMo" in self.name: # this makes it equivalent to `torch.nn.functional.layer_norm` # that is used by OLMo # Table 5 caption in the OLMo paper shows this - https://aclanthology.org/2024.acl-long.841 return partial(torch.nn.LayerNorm, elementwise_affine=False) return getattr(torch.nn, self.norm_class_name) def check_indicator_and_length( params: Optional[List[int]], name: str, required_length: int, use_initial_part: bool = True, def_val: int = 1, ) -> List[int]: if params is None: return [def_val] * required_length if len(params) != required_length: if use_initial_part and len(params) > required_length: params = params[:required_length] else: raise ValueError(f"{name} = {params}, must have length {required_length}") if not set(params).issubset({0, 1}): raise ValueError(f"{name} = {params}, must only contain 0 and 1") return params ######################## # Stability AI StableLM ######################## configs = [ # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json dict(name="stablelm-base-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b")), # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json dict( name="stablelm-base-alpha-7b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"), n_head=48, n_embd=6144, padding_multiple=256, ), # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json dict(name="stablelm-tuned-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), n_head=32), # 
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
    dict(
        name="stablelm-tuned-alpha-7b",
        hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
        n_head=48,
        n_embd=6144,
        padding_multiple=256,
    ),
    # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
    dict(
        name="stablelm-3b-4e1t",
        hf_config=dict(org="stabilityai", name="stablelm-3b-4e1t"),
        padded_vocab_size=50304,
        n_layer=32,
        n_head=32,
        n_embd=2560,
        parallel_residual=False,
        bias=False,
        mlp_class_name="LLaMAMLP",
        intermediate_size=6912,
    ),
    # https://huggingface.co/stabilityai/stablelm-zephyr-3b/blob/main/config.json
    dict(
        name="stablelm-zephyr-3b",
        hf_config=dict(org="stabilityai", name="stablelm-zephyr-3b"),
        padded_vocab_size=50304,
        n_layer=32,
        n_head=32,
        n_embd=2560,
        parallel_residual=False,
        bias=False,
        mlp_class_name="LLaMAMLP",
        intermediate_size=6912,
    ),
]

# Each family below is a list of keyword dicts. It is appended to the global `configs`
# registry, either directly or after expanding `{}` name placeholders into variants.
# Fields omitted from a dict take the `Config` defaults when the dict is used as `Config(**d)`.

##########################
# Stability AI StableCode
##########################
stablecode = [
    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
    dict(
        name="stablecode-completion-alpha-3b",
        hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
        block_size=16384,
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
    dict(
        name="stablecode-completion-alpha-3b-4k",
        hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
    # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
    dict(
        name="stablecode-instruct-alpha-3b",
        hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
    # https://huggingface.co/stabilityai/stable-code-3b/blob/main/config.json
    dict(
        name="stable-code-3b",
        hf_config=dict(org="stabilityai", name="stable-code-3b"),
        padded_vocab_size=50304,
        n_layer=32,
        n_embd=2560,
        block_size=16384,
        parallel_residual=False,
        bias=False,
        mlp_class_name="LLaMAMLP",
        intermediate_size=6912,
    ),
]
configs.extend(stablecode)


####################
# EleutherAI Pythia
####################
pythia = [
    # https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json
    dict(
        name="pythia-14m",
        hf_config=dict(org="EleutherAI", name="pythia-14m"),
        block_size=512,
        n_layer=6,
        n_embd=128,
        n_head=4,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json
    dict(
        name="pythia-31m",
        hf_config=dict(org="EleutherAI", name="pythia-31m"),
        block_size=1024,
        n_layer=6,
        n_embd=256,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
    dict(
        name="pythia-70m",
        hf_config=dict(org="EleutherAI", name="pythia-70m"),
        block_size=2048,
        n_layer=6,
        n_embd=512,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
    dict(
        name="pythia-160m",
        hf_config=dict(org="EleutherAI", name="pythia-160m"),
        block_size=2048,
        n_layer=12,
        n_embd=768,
        n_head=12,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
    dict(
        name="pythia-410m",
        hf_config=dict(org="EleutherAI", name="pythia-410m"),
        block_size=2048,
        n_layer=24,
        n_embd=1024,
        n_head=16,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
    dict(
        name="pythia-1b",
        hf_config=dict(org="EleutherAI", name="pythia-1b"),
        block_size=2048,
        n_embd=2048,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
    dict(
        name="pythia-1.4b",
        hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
        block_size=2048,
        n_layer=24,
        n_embd=2048,
        n_head=16,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
    dict(
        name="pythia-2.8b",
        hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
        block_size=2048,
        n_layer=32,
        n_embd=2560,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
    dict(
        name="pythia-6.9b",
        hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
        block_size=2048,
        n_layer=32,
        padding_multiple=256,
    ),
    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
    dict(
        name="pythia-12b",
        hf_config=dict(org="EleutherAI", name="pythia-12b"),
        block_size=2048,
        n_layer=36,
        n_embd=5120,
        n_head=40,
    ),
]
configs.extend(pythia)
# Also register a "-deduped" copy of each Pythia config. Only the litgpt name and the HF
# repo name change; the architecture fields are copied unchanged.
for c in pythia:
    # "pythia-14m" and "pythia-31m" don't have deduped version
    if c["name"] in ("pythia-14m", "pythia-31m"):
        continue
    copy = deepcopy(c)
    copy["name"] = f"{c['name']}-deduped"
    copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
    configs.append(copy)


#################
# TII UAE Falcon
#################
falcon = [
    # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
    dict(
        name="falcon-7b{}",
        hf_config=dict(org="tiiuae", name="falcon-7b{}"),
        block_size=2048,
        vocab_size=65024,
        padded_vocab_size=65024,
        n_layer=32,
        n_head=71,
        n_embd=4544,
        rotary_percentage=1.0,
        n_query_groups=1,
        bias=False,
        # this is not in the config, but in the original model implementation, only for this config
        shared_attention_norm=True,
    ),
    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
    dict(
        name="falcon-40b{}",
        hf_config=dict(org="tiiuae", name="falcon-40b{}"),
        block_size=2048,
        vocab_size=65024,
        padded_vocab_size=65024,
        n_layer=60,
        n_head=128,
        n_embd=8192,
        rotary_percentage=1.0,
        n_query_groups=8,
        bias=False,
    ),
]
# Expand the `{}` placeholder into the base and "-instruct" variants (deep copies, so the
# template dicts above stay untouched).
for c in falcon:
    for kind in ("", "-instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
falcon180b = dict(
    name="falcon-180B{}",
    hf_config=dict(org="tiiuae", name="falcon-180B{}"),
    block_size=2048,
    vocab_size=65024,
    padded_vocab_size=65024,
    n_layer=80,
    n_head=232,
    n_embd=14848,
    rotary_percentage=1.0,
    n_query_groups=8,
    bias=False,
)

# Expand the `{}` placeholder into the base and "-chat" variants.
for kind in ("", "-chat"):
    copy = deepcopy(falcon180b)
    copy["name"] = falcon180b["name"].format(kind)
    copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
    configs.append(copy)

falcon3 = [
    # https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json
    dict(
        name="Falcon3-1B{}",
        hf_config=dict(org="tiiuae", name="Falcon3-1B{}"),
        block_size=4096,
        vocab_size=131072,
        padded_vocab_size=131072,
        n_layer=18,
        n_head=8,
        n_query_groups=4,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        rope_base=1000042,
        norm_eps=1e-6,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=8192,
    ),
    # https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json
    dict(
        name="Falcon3-3B{}",
        hf_config=dict(org="tiiuae", name="Falcon3-3B{}"),
        block_size=32768,
        vocab_size=131072,
        padded_vocab_size=131072,
        n_layer=22,
        n_head=12,
        n_query_groups=4,
        n_embd=3072,
        rotary_percentage=1.0,
        parallel_residual=False,
        rope_base=1000042,
        norm_eps=1e-6,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=9216,
    ),
    # https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json
    dict(
        name="Falcon3-7B{}",
        hf_config=dict(org="tiiuae", name="Falcon3-7B{}"),
        block_size=32768,
        vocab_size=131072,
        padded_vocab_size=131072,
        n_layer=28,
        n_head=12,
        n_query_groups=4,
        n_embd=3072,
        rotary_percentage=1.0,
        parallel_residual=False,
        rope_base=1000042,
        norm_eps=1e-6,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=23040,
    ),
    # https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json
    dict(
        name="Falcon3-10B{}",
        hf_config=dict(org="tiiuae", name="Falcon3-10B{}"),
        block_size=32768,
        vocab_size=131072,
        padded_vocab_size=131072,
        n_layer=40,
        n_head=12,
        n_query_groups=4,
        n_embd=3072,
        rotary_percentage=1.0,
        parallel_residual=False,
        rope_base=1000042,
        norm_eps=1e-6,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=23040,
    ),
]
# Falcon3 has no unsuffixed name: every entry expands to "-Base" and "-Instruct".
for c in falcon3:
    for kind in ("-Base", "-Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


#############################
# OpenLM Research Open LLaMA
#############################
open_LLaMA = [
    # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
    dict(
        name="open_llama_3b",
        hf_config=dict(org="openlm-research", name="open_llama_3b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        norm_eps=1e-6,
        mlp_class_name="LLaMAMLP",
        intermediate_size=8640,
    ),
    # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
    dict(
        name="open_llama_7b",
        hf_config=dict(org="openlm-research", name="open_llama_7b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        norm_eps=1e-6,
        mlp_class_name="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
    dict(
        name="open_llama_13b",
        hf_config=dict(org="openlm-research", name="open_llama_13b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        norm_eps=1e-6,
        mlp_class_name="LLaMAMLP",
        intermediate_size=13824,
    ),
]
configs.extend(open_LLaMA)


###############
# Meta LLaMA 2
###############
llama_2 = [
    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
    dict(
        name="Llama-2-7b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
    dict(
        name="Llama-2-13b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
    dict(
        name="Llama-2-70b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=28672,
    ),
]
# Expand the `{}` placeholder into the base and "-chat" variants.
for c in llama_2:
    for kind in ("", "-chat"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


###############
# Meta LLaMA 3
###############
llama_3 = [
    # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json
    dict(
        name="Llama-3-8B{}",
        hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"),
        block_size=8192,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=32,
        n_head=32,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=14336,
        rope_base=500000,
    ),
    # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/config.json
    dict(
        name="Llama-3.1-8B{}",
        hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-8B{}"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=32,
        n_head=32,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=14336,
        rope_base=500000,
        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
    # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
    dict(
        name="Llama-3-70B{}",
        hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"),
        block_size=8192,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=28672,
        rope_base=500000,
    ),
    # https://huggingface.co/meta-llama/Meta-Llama-3.1-70B/blob/main/config.json
    dict(
        name="Llama-3.1-70B{}",
        hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-70B{}"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=28672,
        rope_base=500000,
        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
    # https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json
    dict(
        name="Llama-3.1-405B{}",
        hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-405B{}"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=126,
        n_head=128,
        n_embd=16384,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=53248,
        rope_base=500000,
        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
    # https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
    dict(
        name="Llama-3.2-1B{}",
        hf_config=dict(org="meta-llama", name="Llama-3.2-1B{}"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=16,
        n_embd=2048,
        n_head=32,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=8192,
        rope_base=500000,
        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
    # https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
    dict(
        name="Llama-3.2-3B{}",
        hf_config=dict(org="meta-llama", name="Llama-3.2-3B{}"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=28,
        n_embd=3072,
        n_head=24,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=8192,
        rope_base=500000,
        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
    # https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json
    dict(
        name="Llama-3.3-70B-Instruct",
        hf_config=dict(org="meta-llama", name="Llama-3.3-70B-Instruct"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=28672,
        rope_base=500000,
        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
]
for c in llama_3:
    # Llama-3.3-70B-Instruct has no `{}` placeholder, so it is registered once, as-is
    if c["name"] == "Llama-3.3-70B-Instruct":
        configs.append(c)
        continue
    # every other entry expands into the base and "-Instruct" variants
    for kind in ("", "-Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)

#########################
# NVIDIA Llama Nemotron
#########################
configs.append(
    dict(
        name="Llama-3.1-Nemotron-70B-Instruct-HF",
        hf_config=dict(org="nvidia", name="Llama-3.1-Nemotron-70B-Instruct-HF"),
        block_size=131072,
        vocab_size=128000,
        padded_vocab_size=128256,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=28672,
        rope_base=500000,
        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    ),
)
################# # Allen AI OLMo ################# olmo = [ # https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json dict( name="OLMo-1B-hf", hf_config=dict(org="allenai", name="OLMo-1B-hf"), vocab_size=50280, padded_vocab_size=50304, block_size=2048, n_embd=2048, n_layer=16, n_head=16, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="LayerNorm", mlp_class_name="LLaMAMLP", intermediate_size=8192, ), # https://huggingface.co/allenai/OLMo-7B-hf/blob/main/config.json dict( name="OLMo-7B-hf", hf_config=dict(org="allenai", name="OLMo-7B-hf"), vocab_size=50280, padded_vocab_size=50304, block_size=2048, n_layer=32, n_head=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="LayerNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, ), # https://huggingface.co/allenai/OLMo-7B-Instruct-hf/blob/main/config.json dict( name="OLMo-7B-Instruct-hf", hf_config=dict(org="allenai", name="OLMo-7B-Instruct-hf"), vocab_size=50280, padded_vocab_size=50304, block_size=2048, n_layer=32, n_head=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="LayerNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, ), ] configs.extend(olmo) olmo2 = [ # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json dict( name="OLMo-2-1124-7B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"), vocab_size=100278, padded_vocab_size=100352, block_size=4096, n_embd=4096, n_layer=32, n_head=32, n_query_groups=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", norm_eps=1e-06, intermediate_size=11008, rope_base=500000, norm_qk=True, post_mlp_norm=True, norm_1=False, norm_2=False, norm_qk_type="olmo2", post_attention_norm=True, ), # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json dict( name="OLMo-2-1124-13B{}", hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"), vocab_size=100278, 
padded_vocab_size=100352, block_size=4096, n_embd=5120, n_layer=40, n_head=40, n_query_groups=40, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", norm_eps=1e-06, intermediate_size=13824, rope_base=500000, norm_qk=True, post_mlp_norm=True, norm_1=False, norm_2=False, norm_qk_type="olmo2", post_attention_norm=True, ), ] for c in olmo2: for kind in ("", "-SFT", "-DPO", "-Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) ############### # Google Gemma ############### gemma = [ # https://huggingface.co/google/gemma-2b/blob/main/config.json dict( name="Gemma-2b", hf_config=dict(org="google", name="gemma-2b"), scale_embeddings=True, vocab_size=256000, padding_multiple=64, n_embd=2048, n_layer=18, n_head=8, n_query_groups=1, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", intermediate_size=16384, ), # https://huggingface.co/google/gemma-7b/blob/main/config.json dict( name="Gemma-7b", hf_config=dict(org="google", name="gemma-7b"), scale_embeddings=True, vocab_size=256000, padding_multiple=64, n_embd=3072, n_layer=28, n_head=16, head_size=256, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", intermediate_size=24576, ), # https://huggingface.co/google/gemma-2-2b/blob/main/config.json dict( name="Gemma-2-2b", hf_config=dict(org="google", name="gemma-2-2b"), scale_embeddings=True, attention_scores_scalar=256, vocab_size=256000, block_size=8192, sliding_window_size=4096, # only layer with idx 0, 2, 4, ... 
have sliding window attention sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(26)], intermediate_size=9216, n_embd=2304, n_layer=26, n_head=8, n_query_groups=4, head_size=256, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, attention_logit_softcapping=50.0, final_logit_softcapping=30.0, ), # https://huggingface.co/google/gemma-2-9b/blob/main/config.json dict( name="Gemma-2-9b", hf_config=dict(org="google", name="gemma-2-9b"), scale_embeddings=True, attention_scores_scalar=256, vocab_size=256000, block_size=8192, sliding_window_size=4096, # only layer with idx 0, 2, 4, ... have sliding window attention sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(42)], intermediate_size=14336, n_embd=3584, n_layer=42, n_head=16, n_query_groups=8, head_size=256, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, attention_logit_softcapping=50.0, final_logit_softcapping=30.0, ), # https://huggingface.co/google/gemma-2-27b/blob/main/config.json dict( name="Gemma-2-27b", hf_config=dict(org="google", name="gemma-2-27b"), scale_embeddings=True, # In Gemma 2 27B attention scores are scaled not by `sqrt(head_size)` (11.31), # but by `sqrt(n_emb // n_head)` = sqrt(4608 // 32) = 12 attention_scores_scalar=144, vocab_size=256000, block_size=8192, sliding_window_size=4096, # only layer with idx 0, 2, 4, ... 
have sliding window attention sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(46)], intermediate_size=36864, n_embd=4608, n_layer=46, n_head=32, n_query_groups=16, head_size=128, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, attention_logit_softcapping=50.0, final_logit_softcapping=30.0, ), ] configs.extend(gemma) for c in gemma: copy = deepcopy(c) copy["name"] = f"{c['name']}-it" copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it" configs.append(copy) ################## # Google Gemma 3 ################## gemma3 = [ # https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json dict( name="Gemma-3-1b-it", hf_config=dict(org="google", name="gemma-3-1b-it"), scale_embeddings=True, attention_scores_scalar=256, vocab_size=262144, block_size=131072, sliding_window_size=512, # 5 local layers for every global layer sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)], intermediate_size=6912, n_embd=1152, n_layer=26, n_head=4, n_query_groups=1, head_size=256, rotary_percentage=1.0, rope_adjustments=None, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, norm_qk=True, rope_base=1000000, rope_local_base_freq=10000, # 5 local layers for every global layer rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)], ), # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json dict( name="Gemma-3-4b-it", hf_config=dict(org="google", name="gemma-3-4b-it"), scale_embeddings=True, attention_scores_scalar=256, vocab_size=262144, block_size=131072, sliding_window_size=1024, # 5 local layers for every global layer sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)], intermediate_size=10240, n_embd=2560, n_layer=34, n_head=8, n_query_groups=4, head_size=256, 
rotary_percentage=1.0, rope_adjustments=dict(factor=8.0), parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, norm_qk=True, rope_base=1000000, rope_local_base_freq=10000, # 5 local layers for every global layer rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)], ), # https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json dict( name="Gemma-3-12b-it", hf_config=dict(org="google", name="gemma-3-12b-it"), scale_embeddings=True, attention_scores_scalar=256, vocab_size=262144, block_size=131072, sliding_window_size=1024, # 5 local layers for every global layer sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)], intermediate_size=15360, n_embd=3840, n_layer=48, n_head=16, n_query_groups=8, head_size=256, rotary_percentage=1.0, rope_adjustments=dict(factor=8.0), parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, norm_qk=True, rope_base=1000000, rope_local_base_freq=10000, # 5 local layers for every global layer rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)], ), # https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json dict( name="Gemma-3-27b-it", hf_config=dict(org="google", name="gemma-3-27b-it"), scale_embeddings=True, attention_scores_scalar=168, vocab_size=262144, block_size=131072, sliding_window_size=1024, # 5 local layers for every global layer sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)], intermediate_size=21504, n_embd=5376, n_layer=62, n_head=32, n_query_groups=16, head_size=128, rotary_percentage=1.0, rope_adjustments=dict(factor=8.0), parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", post_attention_norm=True, post_mlp_norm=True, norm_qk=True, rope_base=1000000, 
rope_local_base_freq=10000, # 5 local layers for every global layer rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)], ), ] configs.extend(gemma3) ################## # Google CodeGemma ################## codegemma = [ # https://huggingface.co/google/codegemma-7b-it/blob/main/config.json dict( name="CodeGemma-7b-it", hf_config=dict(org="google", name="codegemma-7b-it"), scale_embeddings=True, vocab_size=256000, padding_multiple=64, n_embd=3072, n_layer=28, n_head=16, head_size=256, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="GemmaMLP", gelu_approximate="tanh", intermediate_size=24576, ), ] configs.extend(codegemma) ########################## # Stability AI FreeWilly2 ########################## freewilly_2 = [ # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json dict( name="FreeWilly2", hf_config=dict(org="stabilityai", name="FreeWilly2"), vocab_size=32000, padding_multiple=64, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, ) ] configs.extend(freewilly_2) ################## # Meta Code Llama ################## code_llama = [ # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json dict( name="CodeLlama-7b-hf", hf_config=dict(org="codellama", name="CodeLlama-7b-hf"), block_size=16384, vocab_size=32016, padding_multiple=16, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json dict( name="CodeLlama-13b-hf", hf_config=dict(org="codellama", name="CodeLlama-13b-hf"), block_size=16384, vocab_size=32016, padding_multiple=16, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, parallel_residual=False, bias=False, 
norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=13824, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json dict( name="CodeLlama-34b-hf", hf_config=dict(org="codellama", name="CodeLlama-34b-hf"), block_size=16384, vocab_size=32000, padded_vocab_size=32000, n_layer=48, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=22016, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-70b-hf/blob/main/config.json dict( name="CodeLlama-70b-hf", hf_config=dict(org="codellama", name="CodeLlama-70b-hf"), block_size=16384, vocab_size=32016, padding_multiple=16, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=28672, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json dict( name="CodeLlama-7b-Python-hf", hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"), block_size=16384, vocab_size=32000, padded_vocab_size=32000, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json dict( name="CodeLlama-13b-Python-hf", hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"), block_size=16384, vocab_size=32000, padded_vocab_size=32000, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=13824, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json dict( 
name="CodeLlama-34b-Python-hf", hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"), block_size=16384, vocab_size=32000, padded_vocab_size=32000, n_layer=48, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=22016, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-70b-Python-hf/blob/main/config.json dict( name="CodeLlama-70b-Python-hf", hf_config=dict(org="codellama", name="CodeLlama-70b-Python-hf"), block_size=16384, vocab_size=32016, padding_multiple=16, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=28672, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json dict( name="CodeLlama-7b-Instruct-hf", hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"), block_size=16384, vocab_size=32016, padding_multiple=16, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json dict( name="CodeLlama-13b-Instruct-hf", hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"), block_size=2048, vocab_size=32016, padding_multiple=16, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=13824, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json dict( name="CodeLlama-34b-Instruct-hf", hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"), block_size=16384, vocab_size=32000, 
padded_vocab_size=32000, n_layer=48, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=22016, rope_base=1000000, ), # https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/config.json dict( name="CodeLlama-70b-Instruct-hf", hf_config=dict(org="codellama", name="CodeLlama-70b-Instruct-hf"), block_size=16384, # 32016 is an added token, so not reported in vocab_size # https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/tokenizer_config.json vocab_size=32015, padding_multiple=16, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=28672, rope_base=1000000, ), ] configs.extend(code_llama) ######################## # garage-bAInd Platypus ######################## platypus = [ # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json dict( name="Platypus-30B", hf_config=dict(org="garage-bAInd", name="Platypus-30B"), block_size=2048, padded_vocab_size=32000, n_layer=60, n_head=52, n_embd=6656, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-06, mlp_class_name="LLaMAMLP", intermediate_size=17920, ), # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json dict( name="Platypus2-7B", hf_config=dict(org="garage-bAInd", name="Platypus2-7B"), padded_vocab_size=32000, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=11008, ), # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json dict( name="Platypus2-13B", hf_config=dict(org="garage-bAInd", name="Platypus2-13B"), padded_vocab_size=32000, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, 
parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=13824, ), # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json dict( name="Platypus2-70B", hf_config=dict(org="garage-bAInd", name="Platypus2-70B"), padded_vocab_size=32000, n_layer=80, n_head=64, n_embd=8192, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, ), # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json dict( name="Camel-Platypus2-13B", hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"), padded_vocab_size=32000, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=13824, ), # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json dict( name="Camel-Platypus2-70B", hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"), padded_vocab_size=32000, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, ), # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json dict( name="Stable-Platypus2-13B", hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"), padded_vocab_size=32000, n_layer=40, n_head=40, n_embd=5120, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=13824, ), # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json dict( name="Platypus2-70B-instruct", hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"), padded_vocab_size=32000, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, 
norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, ), ] configs.extend(platypus) ################################## # togethercomputer LLaMA-2-7B-32K ################################## together_llama2_32k = [ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json dict( name="LLaMA-2-7B-32K", hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"), vocab_size=32000, padding_multiple=64, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_condense_ratio=8, ) ] configs.extend(together_llama2_32k) ################ # Microsoft Phi ################ phi = [ # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json dict( name="phi-1_5", hf_config=dict(org="microsoft", name="phi-1_5"), vocab_size=50257, padded_vocab_size=51200, block_size=2048, n_embd=2048, n_layer=24, rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64 shared_attention_norm=True, lm_head_bias=True, gelu_approximate="tanh", ), # https://huggingface.co/microsoft/phi-2/blob/main/config.json dict( name="phi-2", hf_config=dict(org="microsoft", name="phi-2"), vocab_size=50257, padded_vocab_size=51200, block_size=2048, n_embd=2560, n_layer=32, rotary_percentage=0.4, # 32 / (n_embd / n_head) = 32 / 80 shared_attention_norm=True, lm_head_bias=True, gelu_approximate="tanh", ), # https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json dict( name="Phi-3-mini-4k-instruct", hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"), vocab_size=32000, padded_vocab_size=32064, block_size=4096, n_embd=3072, n_layer=32, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=8192, mlp_class_name="LLaMAMLP", parallel_residual=False, sliding_window_size=2048, ), # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json dict( name="Phi-3-mini-128k-instruct", 
hf_config=dict(org="microsoft", name="Phi-3-mini-128k-instruct"), vocab_size=32000, padded_vocab_size=32064, block_size=131072, n_embd=3072, n_layer=32, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=8192, mlp_class_name="LLaMAMLP", parallel_residual=False, sliding_window_size=262145, ), # https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json dict( name="Phi-3.5-mini-instruct", hf_config=dict(org="microsoft", name="Phi-3.5-mini-instruct"), vocab_size=32000, padded_vocab_size=32064, block_size=4096, n_embd=3072, n_layer=32, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=8192, mlp_class_name="LLaMAMLP", parallel_residual=False, ), # https://huggingface.co/microsoft/phi-4/blob/main/config.json dict( name="phi-4", hf_config=dict(org="microsoft", name="phi-4"), vocab_size=100352, padded_vocab_size=100352, block_size=16384, n_embd=5120, n_layer=40, n_head=40, n_query_groups=10, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=17920, rope_base=250000, mlp_class_name="LLaMAMLP", parallel_residual=False, ), # https://huggingface.co/microsoft/Phi-4-reasoning/blob/main/config.json dict( name="Phi-4-reasoning", hf_config=dict(org="microsoft", name="Phi-4-reasoning"), vocab_size=100352, padded_vocab_size=100352, block_size=32768, n_embd=5120, n_layer=40, n_head=40, n_query_groups=10, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=17920, rope_base=500000, mlp_class_name="LLaMAMLP", parallel_residual=False, ), # https://huggingface.co/microsoft/Phi-4-reasoning-plus/blob/main/config.json dict( name="Phi-4-reasoning-plus", hf_config=dict(org="microsoft", name="Phi-4-reasoning-plus"), vocab_size=100352, padded_vocab_size=100352, block_size=32768, n_embd=5120, n_layer=40, n_head=40, n_query_groups=10, rotary_percentage=1.0, bias=False, norm_class_name="RMSNorm", intermediate_size=17920, rope_base=500000, mlp_class_name="LLaMAMLP", 
parallel_residual=False, ), # https://huggingface.co/microsoft/Phi-4-mini-instruct/blob/main/config.json dict( name="Phi-4-mini-instruct", hf_config=dict(org="microsoft", name="Phi-4-mini-instruct"), vocab_size=200019, padded_vocab_size=200064, block_size=131072, n_embd=3072, n_layer=32, n_head=24, n_query_groups=8, rotary_percentage=0.75, bias=False, norm_class_name="RMSNorm", intermediate_size=8192, mlp_class_name="LLaMAMLP", parallel_residual=False, sliding_window_size=262145, ), # https://huggingface.co/microsoft/Phi-4-mini-reasoning/blob/main/config.json dict( name="Phi-4-mini-reasoning", hf_config=dict(org="microsoft", name="Phi-4-mini-reasoning"), vocab_size=200019, padded_vocab_size=200064, block_size=131072, n_embd=3072, n_layer=32, n_head=24, n_query_groups=8, rotary_percentage=0.75, bias=False, norm_class_name="RMSNorm", intermediate_size=8192, mlp_class_name="LLaMAMLP", parallel_residual=False, sliding_window_size=262145, ), ] configs.extend(phi) ############# # Mistral AI ############# configs.append( # https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json dict( name="Mathstral-7B-v0.1", hf_config=dict(org="mistralai", name="mathstral-7B-v0.1"), padded_vocab_size=32768, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, sliding_window_size=4096, ) ) mistral = [ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json dict( name="Mistral-7B-{}v0.1", hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"), padded_vocab_size=32000, block_size=4096, # should be 32768 but sliding window attention is not implemented n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, sliding_window_size=4096, ), # 
https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json dict( name="Mixtral-8x7B-{}v0.1", hf_config=dict(org="mistralai", name="Mixtral-8x7B-{}v0.1"), padded_vocab_size=32000, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMoE", intermediate_size=14336, rope_base=1000000, n_expert=8, n_expert_per_token=2, ), # https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json dict( name="Mixtral-8x22B-{}v0.1", hf_config=dict(org="mistralai", name="Mixtral-8x22B-{}v0.1"), padded_vocab_size=32768, block_size=65536, n_layer=56, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMoE", intermediate_size=16384, n_head=48, n_embd=6144, rope_base=1000000, n_expert=8, n_expert_per_token=2, ), ] for c in mistral: for kind in ("", "Instruct-"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) configs.append( # https://huggingface.co/unsloth/mistral-7b-v0.2/blob/main/config.json dict( name="Mistral-7B-v0.2", hf_config=dict(org="unsloth", name="Mistral-7B-v0.2"), padded_vocab_size=32000, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, ) ) configs.append( # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json dict( name="Mistral-7B-Instruct-v0.2", hf_config=dict(org="mistralai", name="Mistral-7B-Instruct-v0.2"), padded_vocab_size=32000, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, ) ) configs.append( # 
https://huggingface.co/mistralai/Mistral-7B-v0.3/blob/main/config.json dict( name="Mistral-7B-v0.3", hf_config=dict(org="mistralai", name="Mistral-7B-v0.3"), padded_vocab_size=32768, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, ) ) configs.append( # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/blob/main/config.json dict( name="Mistral-7B-Instruct-v0.3", hf_config=dict(org="mistralai", name="Mistral-7B-Instruct-v0.3"), padded_vocab_size=32768, block_size=32768, n_layer=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=14336, ) ) configs.append( # https://huggingface.co/mistralai/Mistral-Large-Instruct-2407/blob/main/config.json dict( name="Mistral-Large-Instruct-2407", hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2407"), padded_vocab_size=32768, block_size=32768, n_layer=88, n_head=96, n_embd=12288, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=28672, ) ) configs.append( # https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/config.json dict( name="Mistral-Large-Instruct-2411", hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2411"), padded_vocab_size=32768, block_size=32768, n_layer=88, n_head=96, n_embd=12288, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", norm_eps=1e-05, mlp_class_name="LLaMAMLP", intermediate_size=28672, ) ) ############ # TinyLlama ############ tiny_llama = [ dict( name="tiny-llama-1.1b{}", hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"), block_size=2048, vocab_size=32000, padding_multiple=64, n_layer=22, n_head=32, n_embd=2048, 
rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", # original TinyLlama use FusedRMSNorm norm_eps=1e-5, mlp_class_name="LLaMAMLP", intermediate_size=5632, n_query_groups=4, ) ] for c in tiny_llama: for kind, hf_postfix in (("", "-intermediate-step-1431k-3T"), ("-chat", "-Chat-v1.0")): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix) configs.append(copy) ############ # MicroLlama ############ micro_llama = [ dict( name="micro-llama-300M", hf_config=dict(org="keeeeenw", name="MicroLlama"), block_size=2048, vocab_size=32000, padding_multiple=64, n_layer=12, n_head=16, n_embd=1024, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", # original TinyLlama and MicroLlama use FusedRMSNorm norm_eps=1e-5, mlp_class_name="LLaMAMLP", intermediate_size=5632, n_query_groups=4, ) ] configs.extend(micro_llama) ########################## # Trelis Function Calling ########################## llama_2_function_calling = [ # https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2/blob/main/config.json dict( name="Llama-2-7b-chat-hf-function-calling-v2", hf_config=dict(org="Trelis", name="Llama-2-7b-chat-hf-function-calling-v2"), padding_multiple=64, n_layer=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, norm_eps=1e-6, block_size=4096, vocab_size=32000, n_head=32, n_embd=4096, rope_base=10000, ) ] configs.extend(llama_2_function_calling) ########## # Qwen2.5 ########## qwen_2_5 = [ # https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json dict( name="Qwen2.5-0.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-0.5B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=151936, n_layer=24, n_head=14, n_embd=896, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, 
norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=4864, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-1.5B/blob/main/config.json dict( name="Qwen2.5-1.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-1.5B{}"), block_size=131072, vocab_size=151643, padded_vocab_size=151936, n_layer=28, n_head=12, n_embd=1536, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=8960, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json dict( name="Qwen2.5-3B{}", hf_config=dict(org="Qwen", name="Qwen2.5-3B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=151936, n_layer=36, n_head=16, n_embd=2048, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json dict( name="Qwen2.5-7B{}", hf_config=dict(org="Qwen", name="Qwen2.5-7B{}"), block_size=131072, vocab_size=151643, padded_vocab_size=152064, n_layer=28, n_head=28, n_embd=3584, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=18944, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-14B/blob/main/config.json dict( name="Qwen2.5-14B{}", hf_config=dict(org="Qwen", name="Qwen2.5-14B{}"), block_size=131072, vocab_size=151643, padded_vocab_size=152064, n_layer=48, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=13824, norm_eps=1e-5, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-32B/blob/main/config.json dict( 
name="Qwen2.5-32B{}", hf_config=dict(org="Qwen", name="Qwen2.5-32B{}"), block_size=131072, vocab_size=151643, padded_vocab_size=152064, n_layer=64, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=27648, norm_eps=1e-5, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-72B/blob/main/config.json dict( name="Qwen2.5-72B{}", hf_config=dict(org="Qwen", name="Qwen2.5-72B{}"), block_size=131072, vocab_size=151643, padded_vocab_size=152064, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=29568, norm_eps=1e-5, rope_base=1000000, ), ] qwen_2_5_coder = [ # https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B/blob/main/config.json dict( name="Qwen2.5-Coder-0.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-0.5B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=151936, n_layer=24, n_head=14, n_embd=896, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=4864, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B/blob/main/config.json dict( name="Qwen2.5-Coder-1.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=151936, n_layer=28, n_head=12, n_embd=1536, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=8960, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-Coder-3B/blob/main/config.json dict( name="Qwen2.5-Coder-3B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-3B{}"), block_size=32768, vocab_size=151643, 
padded_vocab_size=151936, n_layer=36, n_head=16, n_embd=2048, n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-Coder-7B/blob/main/config.json dict( name="Qwen2.5-Coder-7B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=28, n_head=28, n_embd=3584, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=18944, norm_eps=1e-6, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-Coder-14B/blob/main/config.json dict( name="Qwen2.5-Coder-14B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-14B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=48, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=13824, norm_eps=1e-5, rope_base=1000000, ), # https://huggingface.co/Qwen/Qwen2.5-Coder-32B/blob/main/config.json dict( name="Qwen2.5-Coder-32B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"), block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=64, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=27648, norm_eps=1e-5, rope_base=1000000, ), ] qwen_2_5.extend(qwen_2_5_coder) qwen_2_5_math = [ # https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json dict( name="Qwen2.5-Math-1.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Math-1.5B{}"), block_size=4096, vocab_size=151643, padded_vocab_size=151936, n_layer=28, n_head=12, n_embd=1536, 
n_query_groups=2, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=8960, norm_eps=1e-6, rope_base=10000, ), # https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json dict( name="Qwen2.5-Math-7B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Math-7B{}"), block_size=4096, vocab_size=151643, padded_vocab_size=152064, n_layer=28, n_head=28, n_embd=3584, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=18944, norm_eps=1e-6, rope_base=10000, ), # https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json dict( name="Qwen2.5-Math-72B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Math-72B{}"), block_size=4096, vocab_size=151643, padded_vocab_size=152064, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=29568, norm_eps=1e-5, rope_base=10000, ), ] qwen_2_5.extend(qwen_2_5_math) for c in qwen_2_5: for kind in ("", "-Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) qwen_2_5_1m = [ # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-1M/blob/main/config.json dict( name="Qwen2.5-7B-Instruct-1M", hf_config=dict(org="Qwen", name="Qwen2.5-7B-Instruct-1M"), block_size=1010000, vocab_size=151643, padded_vocab_size=152064, n_layer=28, n_head=28, n_embd=3584, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=18944, norm_eps=1e-5, rope_base=10000000, ), # https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json dict( name="Qwen2.5-14B-Instruct-1M", hf_config=dict(org="Qwen", 
name="Qwen2.5-14B-Instruct-1M"), block_size=1010000, vocab_size=151643, padded_vocab_size=152064, n_layer=48, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=13824, norm_eps=1e-5, rope_base=10000000, ), ] configs.extend(qwen_2_5_1m) ########## # QwQ ########## qwq = [ # https://huggingface.co/Qwen/QwQ-32B/blob/main/config.json dict( name="QwQ-32B", hf_config=dict(org="Qwen", name="QwQ-32B"), block_size=131072, vocab_size=151643, padded_vocab_size=152064, n_layer=64, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=27648, norm_eps=1e-5, rope_base=1000000, ), # https://huggingface.co/Qwen/QwQ-32B-Preview/blob/main/config.json dict( name="QwQ-32B-Preview", hf_config=dict(org="Qwen", name="QwQ-32B-Preview"), block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=64, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, attn_bias=True, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=27648, norm_eps=1e-5, rope_base=1000000, ), ] configs.extend(qwq) ########## # Qwen3 ########## qwen_3 = [ # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json dict( name="Qwen3-0.6B{}", hf_config=dict(org="Qwen", name="Qwen3-0.6B{}"), block_size=40960, vocab_size=151643, padded_vocab_size=151936, n_layer=28, n_head=16, n_embd=1024, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=3072, norm_eps=1e-6, rope_base=1000000, head_size=128, norm_qk=True, ), # https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json dict( name="Qwen3-1.7B{}", hf_config=dict(org="Qwen", name="Qwen3-1.7B{}"), block_size=40960, 
vocab_size=151643, padded_vocab_size=151936, n_layer=28, n_head=16, n_embd=2048, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=6144, norm_eps=1e-6, rope_base=1000000, norm_qk=True, ), # https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json dict( name="Qwen3-4B{}", hf_config=dict(org="Qwen", name="Qwen3-4B{}"), block_size=40960, vocab_size=151643, padded_vocab_size=151936, n_layer=36, n_head=32, n_embd=2560, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=9728, norm_eps=1e-6, rope_base=1000000, head_size=128, norm_qk=True, ), # https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json dict( name="Qwen3-8B{}", hf_config=dict(org="Qwen", name="Qwen3-8B{}"), block_size=40960, vocab_size=151643, padded_vocab_size=151936, n_layer=36, n_head=32, n_embd=4096, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=12288, norm_eps=1e-6, rope_base=1000000, norm_qk=True, ), # https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json dict( name="Qwen3-14B{}", hf_config=dict(org="Qwen", name="Qwen3-14B{}"), block_size=40960, vocab_size=151643, padded_vocab_size=151936, n_layer=40, n_head=40, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=17408, norm_eps=1e-6, rope_base=1000000, norm_qk=True, ), ] for c in qwen_3: for kind in ("", "-Base"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) qwen_3_32b = [ # https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json dict( name="Qwen3-32B", hf_config=dict(org="Qwen", name="Qwen3-32B"), block_size=40960, vocab_size=151643, 
padded_vocab_size=151936, n_layer=64, n_head=64, n_embd=5120, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=25600, norm_eps=1e-6, rope_base=1000000, head_size=128, norm_qk=True, ), ] configs.extend(qwen_3_32b) qwen_3_moe = [ # https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json dict( name="Qwen3-30B-A3B", hf_config=dict(org="Qwen", name="Qwen3-30B-A3B"), block_size=40960, head_size=128, vocab_size=151643, padded_vocab_size=151936, n_layer=48, n_head=32, n_embd=2048, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMoE", intermediate_size=6144, moe_intermediate_size=768, norm_eps=1e-6, rope_base=1000000, norm_qk=True, n_expert=128, n_expert_per_token=8, ), # https://huggingface.co/Qwen/Qwen3-30B-A3B-Base/blob/main/config.json dict( name="Qwen3-30B-A3B-Base", hf_config=dict(org="Qwen", name="Qwen3-30B-A3B-Base"), block_size=40960, head_size=128, vocab_size=151643, padded_vocab_size=151936, n_layer=48, n_head=32, n_embd=2048, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMoE", intermediate_size=6144, moe_intermediate_size=768, norm_eps=1e-6, rope_base=1000000, norm_qk=True, n_expert=128, n_expert_per_token=8, ), # https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json dict( name="Qwen3-235B-A22B", hf_config=dict(org="Qwen", name="Qwen3-235B-A22B"), block_size=40960, head_size=128, vocab_size=151643, padded_vocab_size=151936, n_layer=94, n_head=64, n_embd=4096, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMoE", intermediate_size=12288, moe_intermediate_size=1536, norm_eps=1e-6, rope_base=1000000, norm_qk=True, n_expert=128, n_expert_per_token=8, ), ] configs.extend(qwen_3_moe) qwen_3_2507_thinking_instruct = [ 
# https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507/blob/main/config.json dict( name="Qwen3-235B-A22B-{}-2507", hf_config=dict(org="Qwen", name="Qwen3-235B-A22B-{}-2507"), block_size=262144, head_size=128, vocab_size=151643, padded_vocab_size=151936, n_layer=94, n_head=64, n_embd=4096, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMoE", intermediate_size=12288, moe_intermediate_size=1536, norm_eps=1e-6, rope_base=5000000, norm_qk=True, n_expert=128, n_expert_per_token=8, ), # https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507/blob/main/config.json dict( name="Qwen3-30B-A3B-{}-2507", hf_config=dict(org="Qwen", name="Qwen3-30B-A3B-{}-2507"), block_size=262144, head_size=128, vocab_size=151643, padded_vocab_size=151936, n_layer=48, n_head=32, n_embd=2048, n_query_groups=4, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMoE", intermediate_size=6144, moe_intermediate_size=768, norm_eps=1e-6, rope_base=10000000, norm_qk=True, n_expert=128, n_expert_per_token=8, ), # https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507/blob/main/config.json dict( name="Qwen3-4B-{}-2507", hf_config=dict(org="Qwen", name="Qwen3-4B-{}-2507"), block_size=262144, vocab_size=151643, padded_vocab_size=151936, n_layer=36, n_head=32, n_embd=2560, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=9728, norm_eps=1e-6, rope_base=5000000, head_size=128, norm_qk=True, ), ] for c in qwen_3_2507_thinking_instruct: for kind in ("Thinking", "Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) ############# # Salamandra ############# salamandra = [ # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json dict( name="salamandra-2b{}", 
hf_config=dict(org="BSC-LT", name="salamandra-2b{}"), block_size=8192, vocab_size=256000, padded_vocab_size=256000, n_layer=24, n_head=16, n_embd=2048, n_query_groups=16, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=5440, norm_eps=1e-5, rope_base=10000, ), # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json dict( name="salamandra-7b{}", hf_config=dict(org="BSC-LT", name="salamandra-7b{}"), block_size=8192, vocab_size=256000, padded_vocab_size=256000, n_layer=32, n_head=32, n_embd=4096, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=11008, norm_eps=1e-6, rope_base=10000, ), ] for c in salamandra: for kind in ("", "-instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) ############### # SmolLM2 ############### smollm2 = [ # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json dict( name="SmolLM2-135M{}", hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"), block_size=8192, vocab_size=49152, padded_vocab_size=49152, n_layer=30, n_head=9, n_embd=576, n_query_groups=3, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=1536, rope_base=100000, norm_eps=1e-5, ), # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json dict( name="SmolLM2-360M{}", hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"), block_size=8192, vocab_size=49152, padded_vocab_size=49152, n_layer=32, n_head=15, n_embd=960, n_query_groups=5, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=2560, rope_base=100000, norm_eps=1e-5, ), # 
https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json dict( name="SmolLM2-1.7B{}", hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"), block_size=8192, vocab_size=49152, padded_vocab_size=49152, n_layer=24, n_head=32, n_embd=2048, n_query_groups=32, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=8192, rope_base=130000, norm_eps=1e-5, ), ] for c in smollm2: for kind in ("", "-Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) ############### # DeepSeek R1 Distill ############### r1_distill_llama = [ # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/blob/main/config.json dict( name="R1-Distill-Llama-8B", hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-8B"), block_size=131072, vocab_size=128000, padded_vocab_size=128256, n_layer=32, n_head=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=14336, rope_base=500000, rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192), ), # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/blob/main/config.json dict( name="R1-Distill-Llama-70B", hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-70B"), block_size=131072, vocab_size=128000, padded_vocab_size=128256, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, rope_base=500000, rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192), ), ] configs.extend(r1_distill_llama) name_to_config = {config["name"]: config for config in configs} 
================================================ FILE: litgpt/constants.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """Centralized package availability constants for optional dependencies.""" from lightning_utilities.core.imports import RequirementCache # Logger-related constants _SUPPORTED_LOGGERS: tuple[str, ...] = ("csv", "tensorboard", "wandb", "mlflow", "litlogger") # Logger-related optional dependencies _LITLOGGER_AVAILABLE = RequirementCache("litlogger>=0.1.7") _TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") _WANDB_AVAILABLE = RequirementCache("wandb") _MLFLOW_AVAILABLE = RequirementCache("mlflow") _MLFLOW_SKINNY_AVAILABLE = RequirementCache("mlflow-skinny") # PyTorch version-specific constants _TORCH_EQUAL_2_7 = RequirementCache("torch>=2.7.0,<2.8") _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9") # Other optional dependencies _REQUESTS_AVAILABLE = RequirementCache("requests") _THUNDER_AVAILABLE = RequirementCache("thunder") _TRITON_AVAILABLE = RequirementCache("triton") _BITANDBYTES_AVAILABLE = RequirementCache("bitsandbytes") _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 = RequirementCache("bitsandbytes != 0.42.0") _LITDATA_AVAILABLE = RequirementCache("litdata") _LITSERVE_AVAILABLE = RequirementCache("litserve") _JINJA2_AVAILABLE = RequirementCache("jinja2") _SAFETENSORS_AVAILABLE = RequirementCache("safetensors") _HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer") ================================================ FILE: litgpt/data/__init__.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
from litgpt.data.alpaca import Alpaca from litgpt.data.alpaca_2k import Alpaca2k from litgpt.data.alpaca_gpt4 import AlpacaGPT4 from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn from litgpt.data.deita import Deita from litgpt.data.flan import FLAN from litgpt.data.json_data import JSON from litgpt.data.lima import LIMA from litgpt.data.lit_data import LitData from litgpt.data.longform import LongForm from litgpt.data.microllama import MicroLlama from litgpt.data.openwebtext import OpenWebText from litgpt.data.text_files import TextFiles from litgpt.data.tinyllama import TinyLlama from litgpt.data.tinystories import TinyStories __all__ = [ "Alpaca", "Alpaca2k", "AlpacaGPT4", "Deita", "FLAN", "JSON", "LIMA", "LitData", "DataModule", "LongForm", "OpenWebText", "SFTDataset", "TextFiles", "TinyLlama", "TinyStories", "MicroLlama", "get_sft_collate_fn", ] ================================================ FILE: litgpt/data/alpaca.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """Implementation derived from https://github.com/tloen/alpaca-lora""" import json from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union import torch from torch.utils.data import DataLoader, random_split from litgpt.constants import _REQUESTS_AVAILABLE from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer _URL = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json" @dataclass class Alpaca(DataModule): """Alpaca data module for supervised finetuning.""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" val_split_fraction: float = 0.03865 # to get exactly 2000 validation samples, """The fraction of the dataset to use for the validation dataset. 
The rest is used for training.""" prompt_style: Union[str, PromptStyle] = "alpaca" """The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for creating the train/val splits and shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" download_dir: Path = Path("./data/alpaca") """The directory in which the downloaded dataset gets saved.""" file_url: str = field(repr=False, default=_URL) """The URL from where to download the dataset.""" file_name: str = field(repr=False, default="alpaca_data_cleaned_archive.json") """The name of the dataset file to download.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: super().__init__() if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) download_if_missing(self.download_dir / self.file_name, self.file_url) def setup(self, stage: str = "") -> None: with open(self.download_dir / self.file_name, encoding="utf-8") as file: data = json.load(file) # Partition the dataset into train and test train_data, test_data = random_split( data, [1.0 - self.val_split_fraction, self.val_split_fraction], 
generator=torch.Generator().manual_seed(self.seed), ) train_data, test_data = list(train_data), list(test_data) self.train_dataset = SFTDataset( data=train_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) self.test_dataset = SFTDataset( data=test_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def val_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def download_if_missing(file_path: Path, file_url: str, mode: str = "w", stream: bool = False) -> None: """Downloads the raw json data file and saves it in the given destination.""" if file_path.exists() and file_path.stat().st_size > 0: return if not _REQUESTS_AVAILABLE: raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE)) import requests response = requests.get(file_url, stream=stream) with open(file_path, mode, encoding=None if mode == "wb" else "utf-8") as f: if stream: # credit: https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py#L25-L38 from tqdm import tqdm pbar = tqdm( desc=str(file_path), total=int(response.headers.get("content-length", 0)), unit="iB", unit_scale=True, unit_divisor=1024, ) for data in response.iter_content(chunk_size=1024): size = f.write(data) pbar.update(size) pbar.close() else: f.write(response.text) ================================================ FILE: 
litgpt/data/alpaca_2k.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from dataclasses import dataclass, field from pathlib import Path from litgpt.data.alpaca import Alpaca from litgpt.data.base import SFTDataset @dataclass class Alpaca2k(Alpaca): """Alpaca2k data module for supervised finetuning.""" val_split_fraction: float = 0.05 # to get exactly 100 validation samples, """The fraction of the dataset to use for the validation dataset. The rest is used for training.""" download_dir: Path = Path("./data/alpaca2k") """The directory in which the downloaded datasetgets saved.""" repo_id: str = field(repr=False, default="mhenrichsen/alpaca_2k_test") """The URL from where to download the dataset.""" file_name: str = field(repr=False, default="alpaca2k_data_cleaned_archive.json") """The name of the dataset file to download.""" def prepare_data(self) -> None: from datasets import load_dataset load_dataset(self.repo_id, cache_dir=self.download_dir) def setup(self, stage: str = "") -> None: from datasets import load_dataset dataset = load_dataset(self.repo_id, cache_dir=self.download_dir) train_validation_split = dataset["train"].train_test_split(test_size=self.val_split_fraction, seed=self.seed) train_data = train_validation_split["train"] test_data = train_validation_split["test"] self.train_dataset = SFTDataset( data=train_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) self.test_dataset = SFTDataset( data=test_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) ================================================ FILE: litgpt/data/alpaca_gpt4.py ================================================ # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. from dataclasses import dataclass, field from pathlib import Path from litgpt.data.alpaca import Alpaca _URL = "https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json" @dataclass class AlpacaGPT4(Alpaca): """AlpacaGPT4 data module for supervised finetuning.""" val_split_fraction: float = 0.03847 # to get exactly 2000 test samples, """The fraction of the dataset to use for the validation dataset. The rest is used for training.""" download_dir: Path = Path("./data/alpacagpt4") """The directory in which the downloaded datasetgets saved.""" file_url: str = field(repr=False, default=_URL) """The URL from where to download the dataset.""" file_name: str = field(repr=False, default="alpacagpt4_data_cleaned_archive.json") """The name of the dataset file to download.""" ================================================ FILE: litgpt/data/base.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from abc import abstractmethod from functools import partial from typing import Any, Callable, Dict, List, Optional, Union import torch from lightning import LightningDataModule from torch import Tensor from torch.utils.data import Dataset from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer class DataModule(LightningDataModule): """Base class for all data modules in LitGPT.""" @abstractmethod def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None, **kwargs, ) -> None: """All settings that can't be determined at the time of instantiation need to be passed through here before any dataloaders can be accessed. 
""" def setup(self, stage: str = "") -> None: # Stub is to redefine the default signature, because the concept of 'stage' does not exist in LitGPT pass def __repr__(self) -> str: return f"{self.__class__.__name__}()" class SFTDataset(Dataset): """An in-memory dataset for supervised finetuning with `input_ids` and `labels`. Args: data: A list of samples (dicts). The target/label must be stored under the key 'output' and the instruction or other data can be stored under any key as long as it is compatible with the given prompt template. tokenizer: The tokenizer to use. Should match the one that was used to pretrain the model. prompt_style: The style to apply to prompts. See `litgpt.prompts` for a list of available styles. max_seq_length: Truncate sequences that are longer than this value. By default, no truncation is applied. mask_prompt: Whether to mask the prompt section from the label (with ``ignore_index``). ignore_index: The index to use for elements to be ignored in the label. transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the keys in the dataset to the expected 'instruction' and 'output' keys. Returns a dict with two keys: input_ids: The encoded prompt + response labels: Same as input_ids, unless ``mask_prompt=True`` in which case the 'prompt' part is replaced with the ``ignore_index``. 
""" def __init__( self, data: List[Dict[str, str]], tokenizer: Tokenizer, prompt_style: Union[str, PromptStyle], max_seq_length: int = -1, mask_prompt: bool = True, ignore_index: int = -100, transform: Optional[Callable[[Any], Any]] = None, ) -> None: self.data = data self.tokenizer = tokenizer self.prompt_style = ( prompt_style if isinstance(prompt_style, PromptStyle) else PromptStyle.from_name(prompt_style) ) self.max_seq_length = max_seq_length self.mask_prompt = mask_prompt self.ignore_index = ignore_index self.transform = transform def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> Dict[str, Union[Tensor, Dict[str, int]]]: example = self.data[idx] if self.transform is not None: example = self.transform(example) prompt = self.prompt_style.apply(prompt=example["instruction"], **example) encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length) encoded_response = self.tokenizer.encode(example["output"], bos=False, eos=True, max_length=self.max_seq_length) encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64) if self.max_seq_length > 0: # do not slice off last token when self.max_seq_length = -1 encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length] # The labels are the full prompt with response, but with the prompt masked out labels = encoded_prompt_and_response.clone() if self.mask_prompt: labels[: len(encoded_prompt)] = self.ignore_index raw_token_count = len(self.tokenizer.encode(example["instruction"], max_length=self.max_seq_length)) + len( encoded_response ) return { "input_ids": encoded_prompt_and_response, "labels": labels, "token_counts": { "raw": raw_token_count, "raw_plus_prompt_template": len(encoded_prompt_and_response), }, } def get_sft_collate_fn(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100): """Returns the collate function for supervised finetuning (needed in the DataLoader). 
The collate function gets a list of dicts with keys `input_ids` and `labels`. It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in the batch. Optionally truncates all sequences to the specified maximum length. """ return partial(_sft_collate_fn, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index) def _sft_collate_fn( samples: List[Dict[str, Tensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100 ) -> Dict[str, Tensor]: batched = {} for key in ("input_ids", "labels"): pad_value = pad_id if key == "input_ids" else ignore_index # Pad right based on the longest sequence batched[key] = torch.nn.utils.rnn.pad_sequence( [sample[key] for sample in samples], batch_first=True, padding_value=pad_value ) # Truncate if needed if max_seq_length > 0: batched[key] = batched[key][:, :max_seq_length] batched["token_counts"] = {} batched["token_counts"]["raw"] = torch.tensor( # Token count without padding and without prompt template [sample["token_counts"]["raw"] for sample in samples], dtype=torch.int64 ).unsqueeze(1) batched["token_counts"]["raw_plus_prompt_template"] = ( torch.tensor( # Token count without padding but with prompt template [sample["token_counts"]["raw_plus_prompt_template"] for sample in samples], dtype=torch.int64 ).unsqueeze(1) ) return batched ================================================ FILE: litgpt/data/deita.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
"""Implementation derived from https://github.com/tloen/alpaca-lora""" from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Union import torch from torch.utils.data import DataLoader from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer @dataclass class Deita(DataModule): """Deita data module for supervised finetuning.""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" prompt_style: Union[str, PromptStyle] = "alpaca" """The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" include_multiturn_conversations: bool = False """Whether to include multi-turn conversations in the dataset.""" download_dir: Path = Path("./data/deita") """The directory in which the downloaded dataset gets saved.""" repo_id: str = "HuggingFaceH4/deita-10k-v0-sft" """The repo from where the data is downloaded""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: super().__init__() if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if 
max_seq_length is None else max_seq_length def prepare_data(self) -> None: from datasets import load_dataset load_dataset(self.repo_id, split=["train_sft", "test_sft"], cache_dir=self.download_dir) def setup(self, stage: str = "") -> None: from datasets import load_dataset dataset = load_dataset(self.repo_id, split=["train_sft", "test_sft"]) train_data = format_dataset(dataset[0], self.include_multiturn_conversations) test_data = format_dataset(dataset[1], self.include_multiturn_conversations) self.train_dataset = SFTDataset( data=train_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) self.test_dataset = SFTDataset( data=test_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def val_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def format_dataset(dataset: List[dict], include_multi_turn_conversations: bool) -> List[dict]: formatted = [] for entry in dataset: convo = entry["messages"] if include_multi_turn_conversations: for i in range(0, len(convo) - 1, 2): formatted.append({"instruction": convo[i]["content"], "input": "", "output": convo[i + 1]["content"]}) else: formatted.append({"instruction": convo[0]["content"], "input": "", "output": convo[1]["content"]}) return formatted ================================================ FILE: 
litgpt/data/flan.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import json from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List, Optional, Set, Union import torch from torch.utils.data import DataLoader from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.data.alpaca import download_if_missing from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer _URL = "https://huggingface.co/datasets/Muennighoff/flan/resolve/main" # TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache # on disk or use Lightning Data @dataclass class FLAN(DataModule): """FLAN data module for supervised finetuning.""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" prompt_style: Union[str, PromptStyle] = "flan" """The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" download_dir: Path = Path("./data/flan") """The directory in which the downloaded dataset gets saved.""" url: str = _URL """The URL from where to download the dataset.""" subsets: Optional[str] = None """A comma separated list of subsets to use. 
If None, all subsets are used.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self): super().__init__() if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) supported_subsets = _supported_subsets() if self.subsets is not None: self.subsets = self.subsets.split(",") for subset in self.subsets: if subset not in supported_subsets: raise ValueError(f"{subset} not in {supported_subsets}") else: self.subsets = list(supported_subsets) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) for subset in self.subsets: for split in ("train", "test"): data_file_path = self.download_dir / f"{subset}_{split}.jsonl" data_file_url = f"{self.url}/{split}/{subset}_{split}.jsonl" download_if_missing(data_file_path, data_file_url) def train_dataloader(self): return self._dataloader("train") def val_dataloader(self): return self._dataloader("test") def _dataloader(self, split: str) -> DataLoader: data = [] for subset in self.subsets: data_file_path = self.download_dir / f"{subset}_{split}.jsonl" data.extend(load_jsonl(data_file_path)) dataset = SFTDataset( data=data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, transform=_transform, ) return DataLoader( dataset=dataset, batch_size=self.batch_size, shuffle=(split == 
"train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def load_jsonl(filename: Path) -> List[Dict[str, str]]: data = [] with open(filename, encoding="utf-8") as f: for line in f: data.append(json.loads(line)) return data def _transform(item: dict) -> dict: item["instruction"] = item.pop("inputs") item["output"] = item.pop("targets") return item def _supported_subsets() -> Set[str]: return { "aeslc_10templates", "ag_news_subset_10templates", "anli_r1_10templates", "anli_r2_10templates", "anli_r3_10templates", "arc_challenge_10templates", "arc_easy_10templates", "bool_q_10templates", "cb_10templates", "cnn_dailymail_10templates", "cola_10templates", "common_gen_10templates", "copa_10templates", "coqa_10templates", "cosmos_qa_10templates", "dart_10templates", "definite_pronoun_resolution_10templates", "drop_10templates", "e2e_nlg_10templates", "fix_punct_10templates", "gigaword_10templates", "glue_mrpc_10templates", "glue_qqp_10templates", "hellaswag_10templates", "imdb_reviews_10templates", "math_dataset_10templates", "mnli_matched_10templates", "mnli_mismatched_10templates", "multi_news_10templates", "multirc_10templates", "natural_questions_10templates", "openbookqa_10templates", "opinion_abstracts_idebate_10templates", "opinion_abstracts_rotten_tomatoes_10templates", "para_crawl_enes_10templates", "paws_wiki_10templates", "piqa_10templates", "qnli_10templates", "quac_10templates", "record_10templates", "rte_10templates", "samsum_10templates", "sentiment140_10templates", "snli_10templates", "squad_v1_10templates", "squad_v2_10templates", "sst2_10templates", "story_cloze_10templates", "stsb_10templates", "trec_10templates", "trivia_qa_10templates", "true_case_10templates", "web_nlg_en_10templates", "wic_10templates", "wiki_lingua_english_en_10templates", "wmt14_enfr_10templates", "wmt16_translate_csen_10templates", 
"wmt16_translate_deen_10templates", "wmt16_translate_fien_10templates", "wmt16_translate_roen_10templates", "wmt16_translate_ruen_10templates", "wmt16_translate_tren_10templates", "wnli_10templates", "word_segment_10templates", "wsc_10templates", "yelp_polarity_reviews_10templates", } ================================================ FILE: litgpt/data/json_data.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import json import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional, Tuple, Union import torch from torch.utils.data import DataLoader, random_split from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer @dataclass class JSON(DataModule): """Loads JSON or JSONL data for supervised finetuning.""" json_path: Path """A path to a JSON file or a directory with `train.json` and `val.json` containing the data. The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input' (see Alpaca).""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" val_split_fraction: Optional[float] = None """The fraction of the dataset to use for the validation dataset. The rest is used for training. Only applies if you passed in a single file to `json_path`.""" prompt_style: Union[str, PromptStyle] = "alpaca" """The style to apply to instruction prompts. 
See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for creating the train/val splits and shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self): super().__init__() if self.json_path.is_file() and self.val_split_fraction is None: self.val_split_fraction = 0.05 warnings.warn( "The `json_path` points to a single file and `val_split_fraction` was not set. " "Defaulting to `val_split_fraction=0.05`. Set `val_split_fraction` explicitly " "to use a different split percentage.", UserWarning, stacklevel=2, ) if self.json_path.is_dir() and self.val_split_fraction is not None: raise ValueError( "If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and" f" hence `val_split_fraction` should not be set. Got `{self.val_split_fraction=}`." ) if not self.json_path.exists(): raise FileNotFoundError( "The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files," f" but '{self.json_path!s}' does not exist." 
) if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length def setup(self, stage: str = "") -> None: train_data, test_data = self.get_splits() self.train_dataset = SFTDataset( data=train_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) self.test_dataset = SFTDataset( data=test_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def val_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def get_splits(self) -> Tuple: # A single file (gets split into train and test) if self.json_path.is_file(): data = load_split(self.json_path) # Partition the dataset into train and test train_data, test_data = random_split( data, [1.0 - self.val_split_fraction, self.val_split_fraction], generator=torch.Generator().manual_seed(self.seed), ) return train_data, test_data # A directory containing train.json and val.json if (train_file := self.find_split("train")) and (val_file := self.find_split("val")): train_data = load_split(train_file) test_data = load_split(val_file) return 
train_data, test_data raise FileNotFoundError( "The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files." ) def find_split(self, split_name: str) -> Optional[Path]: for suffix in (".json", ".jsonl"): if (file := self.json_path / f"{split_name}{suffix}").is_file(): return file return None def load_split(json_path: Path) -> Any: if json_path.suffix == ".json": with open(json_path, encoding="utf-8") as file: return json.load(file) if json_path.suffix == ".jsonl": with open(json_path, encoding="utf-8") as file: return [json.loads(line) for line in file] else: raise ValueError(f"Unsupported file format: {json_path.suffix}. Expected `.json` or `.jsonl`.") ================================================ FILE: litgpt/data/lima.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """Implementation derived from https://github.com/tloen/alpaca-lora""" import os from dataclasses import dataclass, field from typing import List, Optional, Union import torch from torch.utils.data import DataLoader, random_split from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer @dataclass class LIMA(DataModule): """LIMA data module for supervised finetuning.""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" val_split_fraction: float = 0.1 """The fraction of the dataset to use for the validation dataset. The rest is used for training.""" prompt_style: Union[str, PromptStyle] = "alpaca" """The style to apply to instruction prompts. 
See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for creating the train/val splits and shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" include_multiturn_conversations: bool = False """Whether to include multi-turn conversations in the dataset.""" repo_id: str = "GAIR/lima" """The Hugging Face dataset repository ID from where to download the data.""" access_token: Optional[str] = field(repr=False, default=os.getenv("HF_TOKEN")) """The Hugging Face API token to use for authentication. Can also be set through the `HF_TOKEN` environment variable.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self): super().__init__() if self.access_token is None: raise ValueError( "LIMA requires authentication, please set the `HF_TOKEN=your_token` environment" " variable or pass --access_token=your_token. 
You can find your token by visiting" " https://huggingface.co/settings/tokens" ) if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length def prepare_data(self) -> None: from datasets import load_dataset load_dataset(self.repo_id, token=self.access_token) def setup(self, stage: str = "") -> None: from datasets import load_dataset dataset = load_dataset(self.repo_id, token=self.access_token) data = format_dataset(dataset["train"], self.include_multiturn_conversations) # Partition the dataset into train and test train_data, test_data = random_split( data, [1.0 - self.val_split_fraction, self.val_split_fraction], generator=torch.Generator().manual_seed(self.seed), ) train_data, test_data = list(train_data), list(test_data) self.train_dataset = SFTDataset( data=train_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) self.test_dataset = SFTDataset( data=test_data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, ) def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def val_dataloader(self) -> DataLoader: return DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) 
def format_dataset(dataset_partition: dict, include_multi_turn_conversations: bool) -> List[dict]: formatted_ds = [] for entry in dataset_partition: convo = entry["conversations"] if include_multi_turn_conversations: for i in range(0, len(convo) - 1, 2): formatted_ds.append({"instruction": convo[i], "input": "", "output": convo[i + 1]}) else: formatted_ds.append({"instruction": convo[0], "input": "", "output": convo[1]}) return formatted_ds ================================================ FILE: litgpt/data/lit_data.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Tuple, Union from torch.utils.data import DataLoader from litgpt.data import DataModule from litgpt.tokenizer import Tokenizer @dataclass class LitData(DataModule): """Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks).""" data_path: Union[str, Path] = Path("data/") """The path to the data directory containing the preprocessed chunks for the streaming dataset The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders for training- and validation splits.""" split_names: Optional[Tuple[str, str]] = None """Optional tuple for names of subfolders for training and validation under ``data_path``. 
If not provided, all data under data_path will be used for training, and the validation dataloader will be identical to the train dataloader.""" seed: int = 42 """The random seed for shuffling the dataset.""" num_workers: int = 8 """How many DataLoader processes to use for loading.""" batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) def __post_init__(self) -> None: super().__init__() if self.split_names is not None and len(self.split_names) != 2: raise ValueError("If provided `split_names` must be a tuple of two strings, for example: ('train', 'val').") def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def train_dataloader(self) -> DataLoader: input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path) return self._dataloader(input_dir=input_dir, train=True) def val_dataloader(self) -> DataLoader: input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path) return self._dataloader(input_dir=input_dir, train=False) def _dataloader(self, input_dir: str, train: bool): from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader dataset = StreamingDataset( input_dir=input_dir, item_loader=TokensLoader(block_size=self.seq_length), shuffle=train, seed=self.seed, ) dataloader = StreamingDataLoader( dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return dataloader ================================================ FILE: litgpt/data/longform.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import json from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union import torch from torch.utils.data import DataLoader from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.data.alpaca import download_if_missing from litgpt.prompts import PromptStyle from litgpt.tokenizer import Tokenizer _URL = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset" @dataclass class LongForm(DataModule): """LongForm data module for supervised finetuning.""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" prompt_style: Union[str, PromptStyle] = "longform" """The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -100 """The index to use for elements to be ignored in the label.""" seed: int = 42 """The random seed for shuffling the dataset.""" num_workers: int = 4 """How many DataLoader processes to use for loading.""" download_dir: Path = Path("./data/longform") """The directory in which the downloaded dataset gets saved.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: super().__init__() if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) 
download_if_missing(self.download_dir / "train.json", f"{_URL}/train.json") download_if_missing(self.download_dir / "val.json", f"{_URL}/val.json") def train_dataloader(self): return self._dataloader("train") def val_dataloader(self): return self._dataloader("val") def _dataloader(self, split: str) -> DataLoader: with open(self.download_dir / f"{split}.json", encoding="utf-8") as file: data = json.load(file) dataset = SFTDataset( data=data, tokenizer=self.tokenizer, prompt_style=self.prompt_style, max_seq_length=self.max_seq_length, mask_prompt=self.mask_prompt, ignore_index=self.ignore_index, transform=_transform, ) return DataLoader( dataset=dataset, batch_size=self.batch_size, shuffle=(split == "train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), ) def _transform(item: dict) -> dict: item["instruction"] = item.pop("input") return item ================================================ FILE: litgpt/data/microllama.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from dataclasses import dataclass from pathlib import Path from typing import Union from litgpt.data.tinyllama import TinyLlama @dataclass class MicroLlama(TinyLlama): """The MicroLlama data module is composed of only SlimPajama data.""" def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8): super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False) ================================================ FILE: litgpt/data/openwebtext.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Optional, Union from torch.utils.data import DataLoader from litgpt.data import DataModule from litgpt.tokenizer import Tokenizer @dataclass class OpenWebText(DataModule): """The OpenWebText data module for pretraining.""" data_path: Union[str, Path] = Path("data/openwebtext") """The path to the data directory, containing two folders 'train' and 'val' which are the output of the preprocessing step. The path can also be a remote path (e.g., s3://).""" val_split_fraction: float = 0.0005 """The fraction of data that should be put aside for validation.""" seed: int = 42 """The seed to use for shuffling the training data.""" num_workers: int = 8 """The number of workers to use for the dataloaders.""" tokenizer: Optional[Tokenizer] = field(default=None, repr=False, init=False) batch_size: int = field(default=1, repr=False, init=False) seq_length: int = field(default=2048, repr=False, init=False) def __post_init__(self) -> None: super().__init__() # Could be a remote path (s3://) or a local path self.data_path_train = str(self.data_path).rstrip("/") + "/train" self.data_path_val = str(self.data_path).rstrip("/") + "/val" def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = 2048 ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: from datasets import Dataset, load_dataset from litdata import optimize if str(self.data_path).startswith("s3://"): print(f"The OpenWebText data path points to an S3 location: {self.data_path}. Skipping preprocessing.") return if Path(self.data_path_train).is_dir() and Path(self.data_path_val).is_dir(): print(f"Found OpenWebText train and val dir: {self.data_path}. 
Skipping preprocessing.") return dataset = load_dataset("openwebtext", num_proc=(os.cpu_count() // 2), trust_remote_code=True) # Split the data in training and validation split_dataset = dataset["train"].train_test_split( test_size=self.val_split_fraction, seed=self.seed, shuffle=True ) split_dataset["val"] = split_dataset.pop("test") # rename the test split to val def tokenize(data: Dataset, index: int): yield self.tokenizer.encode(data[index]["text"], eos=True) optimize( fn=partial(tokenize, split_dataset["train"]), inputs=list(range(len(split_dataset["train"]))), output_dir=self.data_path_train, num_workers=min(64, os.cpu_count() - 1), chunk_bytes="200MB", ) optimize( fn=partial(tokenize, split_dataset["val"]), inputs=list(range(len(split_dataset["val"]))), output_dir=self.data_path_val, num_workers=min(8, os.cpu_count() - 1), chunk_bytes="200MB", ) def train_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader train_dataset = StreamingDataset( input_dir=self.data_path_train, item_loader=TokensLoader(block_size=self.seq_length), shuffle=True, ) train_dataloader = StreamingDataLoader( train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return train_dataloader def val_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader val_dataset = StreamingDataset( input_dir=self.data_path_val, item_loader=TokensLoader(block_size=self.seq_length), shuffle=True, ) val_dataloader = StreamingDataLoader( val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return val_dataloader ================================================ FILE: litgpt/data/prepare_slimpajama.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import json import os import time from pathlib import Path from litgpt.data.prepare_starcoder import DataChunkRecipe from litgpt.tokenizer import Tokenizer from litgpt.utils import CLI, extend_checkpoint_dir class SlimPajamaDataRecipe(DataChunkRecipe): is_generator = True def __init__(self, tokenizer: Tokenizer, chunk_size: int): super().__init__(chunk_size) self.tokenizer = tokenizer def prepare_structure(self, input_dir): files = Path(input_dir).rglob("*.zst") return [str(file) for file in files] def prepare_item(self, filepath): import zstandard as zstd with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: for row in f: text = json.loads(row)["text"] if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": continue # exclude the GitHub data since it overlaps with starcoder text_ids = self.tokenizer.encode(string=text, bos=False, eos=True) yield text_ids def prepare( input_dir: Path = Path("data/SlimPajama-627B/train"), output_dir: Path = Path("data/slimpajama/train"), tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), chunk_size: int = (2049 * 16384), fast_dev_run: bool = False, ) -> None: from litdata.processing.data_processor import DataProcessor tokenizer_path = extend_checkpoint_dir(tokenizer_path) tokenizer = Tokenizer(tokenizer_path) data_recipe = SlimPajamaDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size) data_processor = DataProcessor( input_dir=str(input_dir), output_dir=str(output_dir), fast_dev_run=fast_dev_run, num_workers=os.cpu_count(), num_downloaders=1, ) start_time = time.time() data_processor.run(data_recipe) elapsed_time = time.time() - start_time print(f"Time taken: {elapsed_time:.2f} seconds") if __name__ == "__main__": CLI(prepare) ================================================ FILE: litgpt/data/prepare_starcoder.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os
import time
import traceback
from pathlib import Path

from litgpt.constants import _LITDATA_AVAILABLE
from litgpt.tokenizer import Tokenizer
from litgpt.utils import CLI, extend_checkpoint_dir

if _LITDATA_AVAILABLE:
    from litdata.processing.data_processor import DataChunkRecipe
else:
    DataChunkRecipe = object


class StarcoderDataRecipe(DataChunkRecipe):
    """LitData chunk recipe that tokenizes the ``content`` column of Starcoder parquet files."""

    is_generator = True

    def __init__(self, tokenizer: Tokenizer, chunk_size: int):
        super().__init__(chunk_size)
        self.tokenizer = tokenizer

    def prepare_structure(self, input_dir):
        """Return the paths of all ``.parquet`` files found recursively under ``input_dir``."""
        files = Path(input_dir).rglob("*.parquet")
        return [str(file) for file in files]

    def prepare_item(self, item_metadata):
        """Yield token ids (no BOS, with EOS) for every row of the parquet file ``item_metadata``.

        Read errors are printed and end this file's iteration instead of propagating.
        """
        import pyarrow.parquet as pq

        filepath = item_metadata
        start = time.time()

        parquet_file = None
        try:
            parquet_file = pq.ParquetFile(filepath)
            # reduce RAM usage by reading the file in batches
            for batch in parquet_file.iter_batches(batch_size=8192, columns=["content"]):
                for text in batch.to_pandas()["content"]:
                    yield self.tokenizer.encode(text, bos=False, eos=True)
        except Exception:
            print(traceback.format_exc())
            print(f"Error reading {filepath}")
            return
        finally:
            # Close on success, on error, and when the generator is closed early. Previously the
            # error path returned before close() and leaked the handle.
            if parquet_file is not None:
                parquet_file.close()

        end = time.time()
        print(f"Took {end - start:.2f} seconds total", filepath)


def prepare(
    input_dir: Path = Path("data/starcoderdata"),
    output_dir: Path = Path("data/starcoder"),
    tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"),
    chunk_size: int = (2049 * 8192),
    fast_dev_run: bool = False,
) -> None:
    """Tokenize the Starcoder parquet files under ``input_dir`` into LitData chunks in ``output_dir``."""
    from litdata.processing.data_processor import DataProcessor

    tokenizer_path = extend_checkpoint_dir(tokenizer_path)
    tokenizer = Tokenizer(tokenizer_path)
    data_recipe = StarcoderDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)
    data_processor = DataProcessor(
        input_dir=str(input_dir),
        output_dir=str(output_dir),
        fast_dev_run=fast_dev_run,
        num_workers=os.cpu_count(),
        num_downloaders=1,
    )

    start_time = time.time()
    data_processor.run(data_recipe)
    elapsed_time = time.time() - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds")


if __name__ == "__main__":
    CLI(prepare)
================================================ FILE: litgpt/data/text_files.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import glob import os from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Optional from torch.utils.data import DataLoader from litgpt.data import DataModule from litgpt.tokenizer import Tokenizer @dataclass class TextFiles(DataModule): """The TextFile data module used for pretraining. Reads in text data from plaintext files contained in a data folder and provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. """ train_data_path: Path """The path to the data directory used for training that contains .txt files""" val_data_path: Optional[Path] = None """The path to the data directory used for validation that contains .txt files. Splits off data for validation from the training set if None.""" seed: int = 42 """The seed to use for shuffling the dataset.""" num_workers: int = 4 """The number of workers to use for data loading.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) def __post_init__(self) -> None: super().__init__() self.out_path_train = self.train_data_path / "train" if self.val_data_path is None: self.out_path_val = self.train_data_path / "val" else: self.out_path_val = Path(self.val_data_path) / "val" def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: from litdata import optimize from litdata.streaming import TokensLoader 
train_files = sorted(glob.glob(str(self.train_data_path / "*.txt"))) assert len(train_files) > 0, f"No .txt files found in train data {train_files}" if self.val_data_path is not None: self.val_data_path = Path(self.val_data_path) val_files = sorted(glob.glob(str(self.val_data_path / "*.txt"))) assert len(val_files) > 0, f"No .txt files found in validation data {val_files}" # train/test split. let's use only shard 0 for test split, rest train else: assert len(train_files) > 1, f"Expected at least two .txt files in {train_files}" val_files, *train_files = train_files val_files = [val_files] # It's ok to use almost all CPUs here because this runs in a single process num_workers = os.cpu_count() - 1 use_workers = min(num_workers, len(train_files)) if not Path(self.out_path_train).is_dir(): validate_tokenizer(self.tokenizer) optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=train_files, output_dir=str(self.out_path_train), num_workers=use_workers, chunk_bytes="50MB", item_loader=TokensLoader(block_size=self.max_seq_length), ) else: print( f"\nWarning: Preprocessed training data found in {self.out_path_train}." " For efficiency, reprocessing is skipped. If your text input has changed since" " the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger" f" reprocessing: `rm -rf {self.out_path_train}`\n" ) use_workers = min(num_workers, len(val_files)) if not Path(self.out_path_val).is_dir(): validate_tokenizer(self.tokenizer) optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=val_files, output_dir=str(self.out_path_val), num_workers=use_workers, chunk_bytes="50MB", item_loader=TokensLoader(block_size=self.max_seq_length), ) else: print( f"\nWarning: Preprocessed validation data found in {self.out_path_val}." " For efficiency, reprocessing is skipped. 
If your text input has changed since" " the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger" f" reprocessing: `rm -rf {self.out_path_val}`\n" ) def train_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader train_dataset = StreamingDataset( input_dir=str(self.out_path_train), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, ) train_dataloader = StreamingDataLoader( train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return train_dataloader def val_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader val_dataset = StreamingDataset( input_dir=str(self.out_path_val), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, ) val_dataloader = StreamingDataLoader( val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return val_dataloader def tokenize(filename: str, tokenizer: Tokenizer): with open(filename, encoding="utf-8") as file: text = file.read() text = text.strip() yield tokenizer.encode(text, bos=True, eos=False) def validate_tokenizer(tokenizer: Tokenizer) -> None: if tokenizer is None: raise ValueError( "Tokenizer is None. If you are using this data module via `litgpt pretrain`, " "please provide a valid `--tokenizer_dir` path." ) ================================================ FILE: litgpt/data/tinyllama.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union from torch.utils.data import DataLoader from litgpt.data import DataModule from litgpt.tokenizer import Tokenizer @dataclass class TinyLlama(DataModule): """The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data. 
Provides training and validation streaming dataloaders that return batches of tokens. """ data_path: Union[str, Path] = Path("data/") """The path to the data directory, containing two folders 'slimpajama' and 'starcoder' which are the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md` for instructions. The path can also be a remote path (e.g., s3://).""" seed: int = 42 """The random seed for shuffling the dataset.""" num_workers: int = 8 """How many DataLoader processes to use for loading.""" use_starcoder: bool = True """Toggle for using Starcoder data.""" batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) def __post_init__(self): super().__init__() # Could be a remote path (s3://) or a local path self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train" self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val" self.required_paths = [self.slimpajama_train, self.slimpajama_val] if self.use_starcoder: self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder" self.required_paths += [self.starcoder_train] def connect( self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: for path in self.required_paths: if not path.startswith("s3://") and not Path(path).is_dir(): raise FileNotFoundError( "The data path for TinyLlama is expected to be the directory containing these subdirectories:" f" `slimpajama/train`, `slimpajama/val`, `starcoder`. The directory {path} does not exist." 
" Set it via `--data.data_path=...`" ) def train_dataloader(self) -> DataLoader: from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader slim_train_data = StreamingDataset( input_dir=self.slimpajama_train, item_loader=TokensLoader(block_size=self.seq_length), shuffle=True, drop_last=True, ) train_data = slim_train_data if self.use_starcoder: train_datasets = [ slim_train_data, StreamingDataset( input_dir=self.starcoder_train, item_loader=TokensLoader(block_size=self.seq_length), shuffle=True, drop_last=True, ), ] # Mix SlimPajama data and Starcoder data with these proportions: weights = (0.693584, 0.306416) train_data = CombinedStreamingDataset( datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False ) train_dataloader = StreamingDataLoader( train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return train_dataloader def val_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader val_dataset = StreamingDataset( input_dir=self.slimpajama_val, item_loader=TokensLoader(block_size=self.seq_length), shuffle=True, ) val_dataloader = StreamingDataLoader( val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return val_dataloader ================================================ FILE: litgpt/data/tinystories.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import glob import json import os from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Optional from torch.utils.data import DataLoader from tqdm import tqdm from litgpt.data import DataModule from litgpt.data.alpaca import download_if_missing from litgpt.data.text_files import validate_tokenizer from litgpt.tokenizer import Tokenizer @dataclass class TinyStories(DataModule): """The TinyStories data module: https://huggingface.co/datasets/roneneldan/TinyStories Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. """ data_path: Path = Path("data/tinystories") """The path to the data directory, containing two folders 'train' and 'val' which are the output of the preprocessing step.""" seed: int = 42 """The seed to use for shuffling the dataset.""" num_workers: int = 8 """The number of workers to use for the dataloaders.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) def __post_init__(self) -> None: super().__init__() self.data_path_train = self.data_path / "train" self.data_path_val = self.data_path / "val" def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: from litdata import TokensLoader, optimize download(self.data_path) files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json"))) assert len(files) > 0, f"No json files found in {files}" assert len(files) > 1, f"Expected at least two json files in {files}" # train/test split. 
let's use only shard 0 for test split, rest train val_file, *train_files = files num_workers = os.cpu_count() - 1 if not Path(self.data_path_train).is_dir(): validate_tokenizer(self.tokenizer) optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=train_files, output_dir=str(self.data_path_train), num_workers=num_workers, chunk_bytes="200MB", item_loader=TokensLoader(), ) if not Path(self.data_path_val).is_dir(): validate_tokenizer(self.tokenizer) optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=[val_file], output_dir=str(self.data_path_val), num_workers=1, # there's only 1 file chunk_bytes="200MB", item_loader=TokensLoader(), ) def train_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader train_dataset = StreamingDataset( input_dir=str(self.data_path_train), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, ) train_dataloader = StreamingDataLoader( train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return train_dataloader def val_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader val_dataset = StreamingDataset( input_dir=str(self.data_path_val), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, ) val_dataloader = StreamingDataLoader( val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) return val_dataloader def tokenize(filename: str, tokenizer: Tokenizer): with open(filename, encoding="utf-8") as f: data = json.load(f) global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"]) num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"]) local_rank = global_rank % num_workers for example in tqdm(data, position=local_rank): text = example["story"] text = text.strip() # get rid of leading/trailing whitespace tokens = tokenizer.encode(text, bos=True, eos=False) # encode 
the text, use BOS yield tokens _URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" def download(data_dir: Path): data_dir.mkdir(exist_ok=True, parents=True) data_tar = data_dir / "TinyStories_all_data.tar.gz" data_dir = data_dir / "TinyStories_all_data" shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) if shard_filenames: print(f"{data_dir} already exists, skipping unpacking...") return # download the TinyStories dataset, unless it's already downloaded download_if_missing(data_tar, _URL, stream=True, mode="wb") # unpack the tar.gz file into all the data shards (json files) data_dir.mkdir(exist_ok=False) tar_command = f"tar -xzf {data_tar} -C {data_dir}" print(tar_command) os.system(tar_command) shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) print(f"Number of shards: {len(shard_filenames)}") ================================================ FILE: litgpt/deploy/__init__.py ================================================ ================================================ FILE: litgpt/deploy/serve.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import json import sys from pathlib import Path from pprint import pprint from typing import Any, Dict, Literal, Optional import torch from litgpt.api import LLM from litgpt.constants import _JINJA2_AVAILABLE, _LITSERVE_AVAILABLE from litgpt.utils import auto_download_checkpoint if _LITSERVE_AVAILABLE: from litserve import LitAPI, LitServer from litserve.specs.openai import ChatCompletionRequest, OpenAISpec else: LitAPI, LitServer = object, object class BaseLitAPI(LitAPI): def __init__( self, checkpoint_dir: Path, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, precision: Optional[str] = None, temperature: float = 0.8, top_k: int = 50, top_p: float = 1.0, max_new_tokens: int = 50, devices: int = 1, api_path: Optional[str] = None, generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None, ) -> None: if not _LITSERVE_AVAILABLE: raise ImportError(str(_LITSERVE_AVAILABLE)) super().__init__(api_path=api_path) self.checkpoint_dir = checkpoint_dir self.quantize = quantize self.precision = precision self.temperature = temperature self.top_k = top_k self.max_new_tokens = max_new_tokens self.top_p = top_p self.devices = devices self.generate_strategy = generate_strategy def setup(self, device: str) -> None: if ":" in device: accelerator, device = device.split(":") device = f"[{int(device)}]" else: accelerator = device device = 1 print("Initializing model...", file=sys.stderr) self.llm = LLM.load(model=self.checkpoint_dir, distribute=None) self.llm.distribute( devices=self.devices, accelerator=accelerator, quantize=self.quantize, precision=self.precision, generate_strategy=self.generate_strategy or ("sequential" if self.devices is not None and self.devices > 1 else None), ) print("Model successfully initialized.", file=sys.stderr) def decode_request(self, request: Dict[str, Any]) -> Any: prompt = str(request["prompt"]) return prompt class SimpleLitAPI(BaseLitAPI): def __init__( self, checkpoint_dir: Path, 
quantize: Optional[str] = None, precision: Optional[str] = None, temperature: float = 0.8, top_k: int = 50, top_p: float = 1.0, max_new_tokens: int = 50, devices: int = 1, api_path: Optional[str] = None, generate_strategy: Optional[str] = None, ): super().__init__( checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices, api_path=api_path, generate_strategy=generate_strategy, ) def setup(self, device: str): super().setup(device) def predict(self, inputs: str) -> Any: output = self.llm.generate( inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens, ) return output def encode_response(self, output: str) -> Dict[str, Any]: # Convert the model output to a response payload. return {"output": output} class StreamLitAPI(BaseLitAPI): def __init__( self, checkpoint_dir: Path, quantize: Optional[str] = None, precision: Optional[str] = None, temperature: float = 0.8, top_k: int = 50, top_p: float = 1.0, max_new_tokens: int = 50, devices: int = 1, api_path: Optional[str] = None, generate_strategy: Optional[str] = None, ): super().__init__( checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices, api_path=api_path, generate_strategy=generate_strategy, ) def setup(self, device: str): super().setup(device) def predict(self, inputs: torch.Tensor) -> Any: yield from self.llm.generate( inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens, stream=True, ) def encode_response(self, output): for out in output: yield {"output": out} class OpenAISpecLitAPI(BaseLitAPI): def __init__( self, checkpoint_dir: Path, quantize: Optional[str] = None, precision: Optional[str] = None, temperature: float = 0.8, top_k: int = 50, top_p: float = 1.0, max_new_tokens: int = 50, devices: int = 1, api_path: Optional[str] = None, generate_strategy: Optional[str] = None, ): super().__init__( checkpoint_dir, quantize, precision, 
temperature, top_k, top_p, max_new_tokens, devices, api_path=api_path, generate_strategy=generate_strategy, ) def setup(self, device: str): super().setup(device) if not _JINJA2_AVAILABLE: raise ImportError(str(_JINJA2_AVAILABLE)) from jinja2 import Template config_path = self.checkpoint_dir / "tokenizer_config.json" if not config_path.is_file(): raise FileNotFoundError(f"Tokenizer config file not found at {config_path}") with open(config_path, encoding="utf-8") as fp: config = json.load(fp) chat_template = config.get("chat_template", None) if chat_template is None: print("The tokenizer config does not contain chat_template, falling back to a default.") chat_template = "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}Assistant: " self.chat_template = chat_template self.template = Template(self.chat_template) def decode_request(self, request: "ChatCompletionRequest") -> Any: # Apply chat template to request messages return self.template.render(messages=request.messages) def predict(self, inputs: str, context: dict) -> Any: # Extract parameters from context with fallback to instance attributes temperature = context.get("temperature") or self.temperature top_p = context.get("top_p", self.top_p) or self.top_p max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens # Run the model on the input and return the output. 
yield from self.llm.generate( inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True, ) def run_server( checkpoint_dir: Path, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, precision: Optional[str] = None, temperature: float = 0.8, top_k: int = 50, top_p: float = 1.0, max_new_tokens: int = 50, devices: int = 1, accelerator: str = "auto", port: int = 8000, stream: bool = False, openai_spec: bool = False, access_token: Optional[str] = None, api_path: Optional[str] = "/predict", timeout: int = 30, generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None, ) -> None: """Serve a LitGPT model using LitServe. Evaluate a model with the LM Evaluation Harness. Arguments: checkpoint_dir: The checkpoint directory to load the model from. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Optional precision setting to instantiate the model weights in. By default, this will automatically be inferred from the metadata in the given ``checkpoint_dir`` directory. temperature: Temperature setting for the text generation. Value above 1 increase randomness. Values below 1 decrease randomness. top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel generated text but can also lead to more incoherent texts. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. 
Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p max_new_tokens: The number of generation steps to take. devices: How many devices/GPUs to use. accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps". The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU. port: The network port number on which the model is configured to be served. stream: Whether to stream the responses. openai_spec: Whether to use the OpenAISpec and enable OpenAI-compatible API endpoints. When True, the server will provide `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients, making it easy to integrate with existing applications that use the OpenAI API. access_token: Optional API token to access models with restrictions. api_path: The custom API path for the endpoint (e.g., "/my_api/classify"). timeout: Request timeout in seconds. Defaults to 30. generate_strategy: The generation strategy to use. The "sequential" strategy (default for devices > 1) allows running models that wouldn't fit in a single card by partitioning the transformer blocks across all devices and running them sequentially. "tensor_parallel" shards the model using tensor parallelism. If None (default for devices = 1), the model is not distributed. 
""" checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI server = LitServer( api_class( checkpoint_dir=checkpoint_dir, quantize=quantize, precision=precision, temperature=temperature, top_k=top_k, top_p=top_p, max_new_tokens=max_new_tokens, devices=devices, api_path=api_path, generate_strategy=generate_strategy, ), spec=OpenAISpec() if openai_spec else None, accelerator=accelerator, devices=1, stream=stream, timeout=timeout, ) server.run(port=port, generate_client_file=False) ================================================ FILE: litgpt/eval/evaluate.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import json import os from pathlib import Path from pprint import pprint from typing import Optional, Union import torch from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint from litgpt.utils import auto_download_checkpoint, copy_config_files def prepare_results(results, save_filepath, print_results=True): from lm_eval.utils import make_table if print_results: print(make_table(results)) if "groups" in results: print(make_table(results, "groups")) json_result = json.dumps(results, indent=2, ensure_ascii=False, default=str) save_filepath.open("w", encoding="utf-8").write(json_result) def convert_and_evaluate( checkpoint_dir: Path, tasks: Optional[str] = None, out_dir: Optional[Path] = None, force_conversion: bool = False, num_fewshot: Optional[int] = None, batch_size: Union[int, str] = 1, device: Optional[str] = None, dtype: Optional[Union[str, torch.dtype]] = None, limit: Optional[float] = None, seed: int = 1234, save_filepath: Optional[Path] = None, access_token: Optional[str] = None, ) -> None: """Evaluate a model with the LM Evaluation Harness. 
Arguments: checkpoint_dir: Directory where the `lit_model.pth` and tokenizer files are located. out_dir: Directory in which to save the converted checkpoints for evaluation. Saves to `checkpoint_dir`/evaluate by default. force_conversion: Set to `True` to reconvert the model and override an existing model.pth from a previous evaluation call. tasks: CSV of task names to evaluate. Example: "hellaswag,truthfulqa_mc2,mmlu" num_fewshot: Number of examples in few-shot context. batch_size: Batch size configuration as positive integer value (default: 1), "auto", in the format 'auto:N', where 'auto:4' recomputes the batch size 4 times. device: Device to use for evaluation, for example, "cuda" or "cuda:0". limit: Limit on number of examples per task. seed: Random seed. save_filepath: The file where the results will be saved. Saves to `out_dir/results.json` by default. access_token: Optional API token to access models with restrictions. """ if tasks is None: from lm_eval.tasks import TaskManager taskm = TaskManager() print("\n".join(taskm.task_index.keys())) print( "\n\nTo evaluate multiple tasks, you can chain the task names " "listed above via a comma-separated list." "\nFor example: `--tasks 'hellaswag,truthfulqa_mc2,mmlu'`. " "\nTo search for a specific task, use `litgpt evaluate list | grep task_name`." 
) return checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) if not (isinstance(batch_size, int) and batch_size > 0) and not ( isinstance(batch_size, str) and batch_size.startswith("auto") ): raise ValueError("batch_size must be a positive integer, 'auto', or in the format 'auto:N'.") from lm_eval import evaluator if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if out_dir is None: out_dir = checkpoint_dir / "evaluate" else: out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) save_filepath = out_dir / Path("results.json") if save_filepath is None else Path(save_filepath) model_path = out_dir / "pytorch_model.bin" if not model_path.exists() or force_conversion: copy_config_files(source_dir=checkpoint_dir, out_dir=out_dir) convert_lit_checkpoint(checkpoint_dir=checkpoint_dir, output_dir=out_dir) # Hack: LitGPT's conversion doesn't save a pickle file that is compatible to be loaded with # `torch.load(..., weights_only=True)`, which is a requirement in HFLM. # So we're `torch.load`-ing and `torch.save`-ing it again to work around this. 
state_dict = torch.load(out_dir / "model.pth") torch.save(state_dict, model_path) os.remove(out_dir / "model.pth") from lm_eval.models.huggingface import HFLM model = HFLM(pretrained=str(out_dir.resolve()), device=device, batch_size=batch_size, dtype=dtype) os.environ["TOKENIZERS_PARALLELISM"] = "false" results = evaluator.simple_evaluate( model=model, tasks=tasks.split(","), num_fewshot=num_fewshot, batch_size=batch_size, device=device, limit=limit, random_seed=seed, numpy_random_seed=seed, torch_random_seed=seed, ) prepare_results(results, save_filepath) ================================================ FILE: litgpt/finetune/__init__.py ================================================ ================================================ FILE: litgpt/finetune/adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses import math import os import time import warnings from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor from torch.utils.data import ConcatDataset, DataLoader from torchmetrics import RunningMean from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.parser_config import save_hyperparameters from litgpt.prompts import save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.types import LoggerChoice from litgpt.utils import ( CycleIterator, auto_download_checkpoint, check_nvlink_connectivity, check_valid_checkpoint_dir, 
choose_logger, chunked_cross_entropy, copy_config_files, create_finetuning_performance_report, get_default_supported_precision, init_out_dir, instantiate_bnb_optimizer, instantiate_torch_optimizer, load_checkpoint, num_parameters, parse_devices, select_sft_generate_example, ) def setup( checkpoint_dir: Path, out_dir: Path = Path("out/finetune/adapter"), precision: Optional[str] = None, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, devices: Union[int, str] = 1, num_nodes: int = 1, data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, global_batch_size=16, micro_batch_size=1, lr_warmup_steps=100, epochs=5, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), log: LogArgs = LogArgs(), optimizer: Union[str, Dict] = "AdamW", logger_name: LoggerChoice = "csv", seed: int = 1337, access_token: Optional[str] = None, ) -> None: """Finetune a model using the Adapter method. Arguments: checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning. out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in /teamspace/jobs//share. precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. devices: How many devices/GPUs to use. num_nodes: How many nodes the code is being run on. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. 
access_token: Optional API token to access models with restrictions. """ checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) out_dir = init_out_dir(out_dir) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") precision = precision or get_default_supported_precision(training=True) logger = choose_logger( logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval, log_args=dataclasses.asdict(log), ) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None if devices * num_nodes > 1: if quantize: raise NotImplementedError( "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1" " when using the --quantize flag." 
) strategy = FSDPStrategy( auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, state_dict_type="full", limit_all_gathers=True, cpu_offload=False, ) else: strategy = "auto" fabric = L.Fabric( devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger, plugins=plugins, ) if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes) def main( fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, ) -> None: validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(fabric.world_size > 1)): model = GPT(config) mark_only_adapter_as_trainable(model) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) from bitsandbytes.nn import StableEmbedding old_embedding = model.transformer.wte model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim) with torch.no_grad(): model.transformer.wte.weight.copy_(old_embedding.weight) 
model.transformer.wte = model.transformer.wte.to( device=old_embedding.weight.device, dtype=old_embedding.weight.dtype ) else: optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) # strict=False because missing keys due to Adapter weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) train_time = time.perf_counter() token_counts = fit( fabric=fabric, model=model, optimizer=optimizer, scheduler=scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, devices=devices, num_nodes=num_nodes, checkpoint_dir=checkpoint_dir, out_dir=out_dir, train=train, eval=eval, data=data, ) training_time = time.perf_counter() - train_time output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type) fabric.print(output) # Final evaluation if eval.final_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics) fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") # Save the final Adapter checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth.adapter" save_path.parent.mkdir(parents=True, exist_ok=True) save_adapter_checkpoint(fabric, model, save_path) if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) save_hyperparameters(setup, save_path.parent) save_prompt_style(data.prompt_style, save_path.parent) def fit( fabric: L.Fabric, model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, train_dataloader: DataLoader, val_dataloader: DataLoader, devices: int, checkpoint_dir: Path, out_dir: Path, train: 
TrainArgs, eval: EvalArgs, data: DataModule, num_nodes: int = 1, ) -> None: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length( ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]) ) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) if eval.initial_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) val_loss = f"{val_loss:.3f}" else: fabric.print("Verifying settings ...") validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check val_loss = "n/a" train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) max_steps = train.max_steps or float("inf") step_count = 0 iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() token_counts = { "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long), } while step_count < max_steps: iter_num += 1 iter_t0 = time.perf_counter() batch = next(train_iterator) if train_iterator.epoch >= train.epochs: break input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids, lm_head_chunk_size=128) # shift the targets such that output n predicts token n+1 logits[-1] = 
logits[-1][..., :-1, :] loss = chunked_cross_entropy(logits, targets[..., 1:]) fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) running_loss.update(loss.detach()) if not is_accumulating: optimizer.step() optimizer.zero_grad() scheduler.step() step_count += 1 token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item() token_counts["raw_tokens_plus_prompt_template"] += ( batch["token_counts"]["raw_plus_prompt_template"].sum().item() ) token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel() total_lengths += input_ids.numel() if iter_num % train.log_interval == 0: loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths ) throughput.compute_and_log(step=iter_num) metrics = { "loss": loss, "iter": iter_num, "step": step_count, "epoch": train_iterator.epoch, "iter_time": t1 - iter_t0, "tokens": token_counts["raw_tokens_plus_prompt_template"], "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size, "learning_rate": scheduler.get_last_lr()[0], } if isinstance(val_loss, torch.Tensor): val_loss = f"{val_loss:.3f}" fabric.print( f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" f" loss train: {metrics['loss']:.3f}," f" val: {val_loss} |" f" iter time: {metrics['iter_time'] * 1000:.2f} ms" f"{' (step)' if not is_accumulating else ''}" ) fabric.log_dict(metrics, step=iter_num) if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_dataloader, eval) generate_example(fabric, model, tokenizer, eval, data) t1 = time.perf_counter() - t0 val_loss_tensor = val_loss.detach().clone().to(fabric.device) val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32) fabric.all_reduce(val_loss_tensor, 
reduce_op="mean") fabric.all_reduce(val_time_tensor, reduce_op="mean") fabric.print( f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms" ) metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)} fabric.log_dict(metrics, step=iter_num) fabric.barrier() if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) save_adapter_checkpoint(fabric, model, checkpoint_file) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) save_prompt_style(data.prompt_style, checkpoint_file.parent) total_token_counts = {} for key in token_counts: total = fabric.all_reduce(token_counts[key], reduce_op="sum") total_token_counts[key] = total.item() return total_token_counts # FSDP has issues with `inference_mode` @torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True ) -> torch.Tensor: if verbose: fabric.print("Validating ...") model.eval() losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) for k, batch in enumerate(val_dataloader): if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() model.train() return val_loss # the adapter "kv cache" cannot be initialized under `inference_mode` @torch.no_grad() def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): instruction = select_sft_generate_example(eval, data) fabric.print(instruction) prompt = data.prompt_style.apply(instruction) encoded = tokenizer.encode(prompt, device=fabric.device) model.eval() with 
fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") else: print( f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " f"The model's supported context size (post-training) is {model.config.block_size}." ) def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) return train_dataloader, val_dataloader def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves 
memory!) lengths = [len(d["input_ids"]) for d in data] longest_seq_length = max(lengths) longest_seq_ix = lengths.index(longest_seq_length) return longest_seq_length, longest_seq_ix def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: fabric.print(f"Saving adapter weights to {str(file_path)!r}") fabric.save(file_path, {"model": model}, filter={"model": adapter_filter}) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if not train.epochs and not train.max_steps: issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/finetune/adapter_v2.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import dataclasses
import math
import os
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union

import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean

from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
    CycleIterator,
    auto_download_checkpoint,
    check_nvlink_connectivity,
    check_valid_checkpoint_dir,
    choose_logger,
    chunked_cross_entropy,
    copy_config_files,
    create_finetuning_performance_report,
    get_default_supported_precision,
    init_out_dir,
    instantiate_bnb_optimizer,
    instantiate_torch_optimizer,
    load_checkpoint,
    load_checkpoint_update,
    num_parameters,
    parse_devices,
    select_sft_generate_example,
)


def setup(
    checkpoint_dir: Path,
    out_dir: Path = Path("out/finetune/adapter-v2"),
    precision: Optional[str] = None,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
    devices: Union[int, str] = 1,
    num_nodes: int = 1,
    resume: Optional[bool] = False,
    data: Optional[DataModule] = None,
    train: TrainArgs = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=16,
        micro_batch_size=1,
        lr_warmup_steps=100,
        epochs=5,
        max_seq_length=None,
    ),
    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
    log: LogArgs = LogArgs(),
    optimizer: Union[str, Dict] = "AdamW",
    logger_name: LoggerChoice = "csv",
    seed: int = 1337,
    access_token: Optional[str] = None,
) -> None:
    """Finetune a model using the Adapter V2 method.

    Arguments:
        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
            /teamspace/jobs/<job-name>/share.
        precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
        quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
            Cannot be combined with a mixed precision setting or with multi-device training.
        devices: How many devices/GPUs to use.
        num_nodes: How many nodes the code is being run on.
        resume: If ``True``, continue from the adapter checkpoint with the highest step number found under
            ``out_dir/step-*/``. If no such checkpoint exists, finetuning starts from the base checkpoint.
        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
        log: Logger-specific arguments, forwarded to ``choose_logger`` as ``log_args``.
        optimizer: An optimizer name (such as "AdamW") or config.
        logger_name: The name of the logger to send metrics to.
        seed: The random seed to use for reproducibility.
        access_token: Optional API token to access models with restrictions.
    """
    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
    pprint(locals())
    data = Alpaca() if data is None else data
    devices = parse_devices(devices)
    out_dir = init_out_dir(out_dir)

    check_valid_checkpoint_dir(checkpoint_dir)
    config = Config.from_file(checkpoint_dir / "model_config.yaml")

    precision = precision or get_default_supported_precision(training=True)
    logger = choose_logger(
        logger_name,
        out_dir,
        name=f"finetune-{config.name}",
        resume=bool(resume),  # same as litgpt/finetune/full.py: lets loggers continue a resumed run
        log_interval=train.log_interval,
        log_args=dataclasses.asdict(log),
    )

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
            warnings.warn(
                "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
            )
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        # quantize[4:] strips the "bnb." prefix; the plugin takes over precision handling from Fabric.
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        precision = None

    if devices * num_nodes > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
                " when using the --quantize flag."
            )
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    fabric = L.Fabric(
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
        loggers=logger,
        plugins=plugins,
    )

    if torch.cuda.is_available() and devices > 1:
        check_nvlink_connectivity(fabric)

    fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)


def _checkpoint_step(path: Path) -> Optional[int]:
    """Return the step encoded in ``path``'s ``step-XXXXXX`` parent directory name, or ``None`` if it is malformed."""
    try:
        return int(path.parent.name.split("-")[1])
    except (IndexError, ValueError):
        return None


def _find_latest_adapter_v2_checkpoint(out_dir: Path) -> Optional[Tuple[Path, int]]:
    """Locate the adapter checkpoint with the highest step number under ``out_dir``.

    Checkpoints are the ``step-*/*.pth.adapter_v2`` files written by ``fit``. Directories whose suffix is not an
    integer are ignored.

    Returns:
        ``(checkpoint_path, step)`` for the newest checkpoint, or ``None`` if there is none.
    """
    candidates = []
    for path in out_dir.rglob("step-*/*.pth.adapter_v2"):
        step = _checkpoint_step(path)
        if step is not None:
            candidates.append((path, step))
    if not candidates:
        return None
    return max(candidates, key=lambda candidate: candidate[1])


def main(
    fabric: L.Fabric,
    devices: int,
    seed: int,
    config: Config,
    data: DataModule,
    resume: bool,
    checkpoint_dir: Path,
    out_dir: Path,
    train: TrainArgs,
    eval: EvalArgs,
    optimizer: Union[str, Dict],
    num_nodes: int = 1,
) -> None:
    """Per-process entry point run by ``fabric.launch``.

    Steps:
        1. Build the data loaders, the Adapter V2 model, the optimizer and the LR scheduler.
        2. Load the base weights, plus the newest adapter checkpoint when resuming.
        3. Train via ``fit`` and optionally run a final validation.
        4. Save the final adapter weights together with config, hyperparameters and prompt style.
    """
    validate_args(train, eval)

    tokenizer = Tokenizer(checkpoint_dir)
    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))

    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    checkpoint_path = checkpoint_dir / "lit_model.pth"
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    mark_only_adapter_v2_as_trainable(model)

    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

    model = fabric.setup_module(model)
    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())

        from bitsandbytes.nn import StableEmbedding

        # Swap the token embedding for bitsandbytes' StableEmbedding, keeping its weights, device and dtype.
        old_embedding = model.transformer.wte
        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
        with torch.no_grad():
            model.transformer.wte.weight.copy_(old_embedding.weight)
        model.transformer.wte = model.transformer.wte.to(
            device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
        )
    else:
        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

    if resume:
        # Finding last trace of adapter training. Only the discovery step may fall back to a fresh start; errors
        # raised while loading a found checkpoint are not swallowed.
        latest = _find_latest_adapter_v2_checkpoint(out_dir)
        if latest is not None:
            resume_path, _ = latest
            fabric.print(f"Resuming training from {resume_path}")
            load_checkpoint_update(fabric, resume_path, model, checkpoint_path, strict=False)
            resume = True
        else:
            fabric.print("No previous adapter found. Finetune from start.")
            resume = False
            load_checkpoint(fabric, model, checkpoint_path, strict=False)
    else:
        # strict=False because missing keys due to Adapter weights not contained in state dict
        load_checkpoint(fabric, model, checkpoint_path, strict=False)

    # Re-mark trainable parameters after the weights have been loaded.
    mark_only_adapter_v2_as_trainable(model)

    train_time = time.perf_counter()
    token_counts = fit(
        fabric=fabric,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        devices=devices,
        resume=resume,
        num_nodes=num_nodes,
        checkpoint_dir=checkpoint_dir,
        out_dir=out_dir,
        train=train,
        eval=eval,
        data=data,
    )
    training_time = time.perf_counter() - train_time
    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
    fabric.print(output)

    # Final evaluation
    if eval.final_validation:
        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
        metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
        fabric.log_dict(metrics)
        fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")

    # Save the final Adapter checkpoint at the end of training
    save_path = out_dir / "final" / "lit_model.pth.adapter_v2"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    save_adapter_v2_checkpoint(fabric, model, save_path)
    if fabric.global_rank == 0:
        # Copy checkpoint files from original checkpoint dir
        copy_config_files(checkpoint_dir, save_path.parent)
        save_hyperparameters(setup, save_path.parent)
        save_prompt_style(data.prompt_style, save_path.parent)


def fit(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    devices: int,
    resume: bool,
    checkpoint_dir: Path,
    out_dir: Path,
    train: TrainArgs,
    eval: EvalArgs,
    data: DataModule,
    num_nodes: int = 1,
) -> Dict[str, int]:
    """Run the Adapter V2 training loop.

    Training stops when ``train.max_steps`` optimizer steps have been taken or ``train.epochs`` epochs have been
    consumed, whichever comes first. The loop also does the following periodically:
        - logs metrics every ``train.log_interval`` iterations;
        - validates and generates an example every ``eval.interval`` steps;
        - saves adapter checkpoints every ``train.save_interval`` steps.

    Returns:
        Token counts for the run (``raw_tokens``, ``raw_tokens_plus_prompt_template``,
        ``raw_tokens_plus_prompt_template_and_padding``), summed across all processes.
    """
    tokenizer = Tokenizer(checkpoint_dir)
    longest_seq_length, _ = get_longest_seq_length(ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]))
    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
    fabric.print(
        f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
        f" {model.max_seq_length} and context length is {model.config.block_size}"
    )

    if eval.initial_validation:
        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
        val_loss = f"{val_loss:.3f}"
    else:
        fabric.print("Verifying settings ...")
        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check
        val_loss = "n/a"

    train_iterator = CycleIterator(train_dataloader)
    throughput = ThroughputMonitor(fabric, window_size=50)
    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
        fabric.device
    )
    max_steps = train.max_steps or float("inf")
    step_count = 0
    iter_num = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    token_counts = {
        "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
        "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
        "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
    }

    if resume:
        # Continue the step counter from the checkpoint `main` resumed from. Without this, checkpoint directories
        # would be re-numbered from zero, and `max_steps` would no longer count total steps.
        # NOTE: adapter checkpoints hold only model weights, so optimizer, LR scheduler and data loader start fresh.
        latest = _find_latest_adapter_v2_checkpoint(out_dir)
        step_count = latest[1] if latest is not None else 0
        fabric.print(f"Starting at step count {step_count}")

    while step_count < max_steps and train_iterator.epoch < train.epochs:
        iter_num += 1
        iter_t0 = time.perf_counter()
        batch = next(train_iterator)
        if train_iterator.epoch >= train.epochs:
            break
        input_ids, targets = batch["input_ids"], batch["labels"]

        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # With `lm_head_chunk_size` the logits are indexed as chunks; only the last chunk is trimmed below.
            logits = model(input_ids, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

        running_loss.update(loss.detach())

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            step_count += 1

        token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
        token_counts["raw_tokens_plus_prompt_template"] += (
            batch["token_counts"]["raw_plus_prompt_template"].sum().item()
        )
        token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()

        total_lengths += input_ids.numel()
        if iter_num % train.log_interval == 0:
            loss = running_loss.compute().item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
            )
            throughput.compute_and_log(step=iter_num)
            metrics = {
                "loss": loss,
                "iter": iter_num,
                "step": step_count,
                "epoch": train_iterator.epoch,
                "iter_time": t1 - iter_t0,
                "tokens": token_counts["raw_tokens_plus_prompt_template"],
                "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
                "learning_rate": scheduler.get_last_lr()[0],
            }
            if isinstance(val_loss, torch.Tensor):
                val_loss = f"{val_loss:.3f}"
            fabric.print(
                f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
                f" loss train: {metrics['loss']:.3f},"
                f" val: {val_loss} |"
                f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
                f"{' (step)' if not is_accumulating else ''}"
            )
            fabric.log_dict(metrics, step=iter_num)

        if not is_accumulating and step_count % eval.interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader, eval)
            generate_example(fabric, model, tokenizer, eval, data)
            t1 = time.perf_counter() - t0

            # Use the returned tensors rather than relying on the reduction happening in place.
            val_loss_tensor = fabric.all_reduce(val_loss.detach().clone().to(fabric.device), reduce_op="mean")
            val_time_tensor = fabric.all_reduce(
                torch.tensor(t1, device=fabric.device, dtype=torch.float32), reduce_op="mean"
            )

            fabric.print(
                f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
            )
            metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
            fabric.log_dict(metrics, step=iter_num)
            fabric.barrier()

        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter_v2"
            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
            save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
            if fabric.global_rank == 0:
                copy_config_files(checkpoint_dir, checkpoint_file.parent)
                save_hyperparameters(setup, checkpoint_file.parent)
                save_prompt_style(data.prompt_style, checkpoint_file.parent)

    # Sum the per-process counters so the report reflects the whole run.
    total_token_counts = {}
    for key, count in token_counts.items():
        total_token_counts[key] = fabric.all_reduce(count, reduce_op="sum").item()

    return total_token_counts


# FSDP has issues with `inference_mode`
@torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True ) -> torch.Tensor: if verbose: fabric.print("Validating ...") model.eval() losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) for k, batch in enumerate(val_dataloader): if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() model.train() return val_loss # the adapter "kv cache" cannot be initialized under `inference_mode` @torch.no_grad() def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): instruction = select_sft_generate_example(eval, data) fabric.print(instruction) prompt = data.prompt_style.apply(instruction) encoded = tokenizer.encode(prompt, device=fabric.device) model.eval() max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") else: print( f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " f"The model's supported context size (post-training) is {model.config.block_size}." 
) def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) return train_dataloader, val_dataloader def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] longest_seq_length = max(lengths) longest_seq_ix = lengths.index(longest_seq_length) return longest_seq_length, longest_seq_ix def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: fabric.print(f"Saving adapter v2 weights to {str(file_path)!r}") fabric.save(file_path, {"model": model}, filter={"model": adapter_filter}) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. 
This is set in {args}") required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if not train.epochs and not train.max_steps: issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/finetune/full.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses import math import os import time from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union import lightning as L import torch from lightning.fabric.strategies import FSDPStrategy from torch.utils.data import ConcatDataset, DataLoader from torchmetrics import RunningMean from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.model import GPT, Block, Config from litgpt.parser_config import save_hyperparameters from litgpt.prompts import save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.types import LoggerChoice from litgpt.utils import ( CycleIterator, auto_download_checkpoint, check_nvlink_connectivity, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, create_finetuning_performance_report, find_resume_path, get_default_supported_precision, init_out_dir, instantiate_torch_optimizer, load_checkpoint, num_parameters, parse_devices, select_sft_generate_example, ) def setup( checkpoint_dir: Path, out_dir: Path = Path("out/finetune/full"), precision: Optional[str] = None, devices: Union[int, str] = 1, num_nodes: int = 1, resume: Union[bool, Literal["auto"], Path] = False, data: Optional[DataModule] = None, train: 
TrainArgs = TrainArgs( save_interval=1000, log_interval=1, global_batch_size=16, micro_batch_size=1, lr_warmup_steps=100, epochs=5, max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), log: LogArgs = LogArgs(), optimizer: Union[str, Dict] = "AdamW", logger_name: LoggerChoice = "csv", seed: int = 1337, access_token: Optional[str] = None, ) -> None: """Finetune a model. Arguments: checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning. out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in /teamspace/jobs//share. precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". devices: How many devices/GPUs to use num_nodes: How many nodes the code is being run on. resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. access_token: Optional API token to access models with restrictions. 
""" checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) out_dir = init_out_dir(out_dir) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") precision = precision or get_default_supported_precision(training=True) logger = choose_logger( logger_name, out_dir, name=f"finetune-{config.name}", resume=bool(resume), log_interval=train.log_interval, log_args=dataclasses.asdict(log), ) if devices * num_nodes > 1: strategy = FSDPStrategy( auto_wrap_policy={Block}, activation_checkpointing_policy={Block}, state_dict_type="full", limit_all_gathers=True, cpu_offload=False, ) else: strategy = "auto" fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger) if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes) def main( fabric: L.Fabric, devices: int, resume: Union[bool, Literal["auto"], Path], seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, ) -> None: validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(fabric.world_size > 1)): model = GPT(config) fabric.print(f"Number of trainable 
parameters: {num_parameters(model, requires_grad=True):,}") model = fabric.setup(model) optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0} resume = find_resume_path(resume, out_dir) if resume: fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) else: load_checkpoint(fabric, state["model"], checkpoint_path) train_time = time.perf_counter() token_counts = fit( fabric=fabric, state=state, train_dataloader=train_dataloader, val_dataloader=val_dataloader, devices=devices, num_nodes=num_nodes, resume=resume, checkpoint_dir=checkpoint_dir, out_dir=out_dir, train=train, eval=eval, data=data, ) training_time = time.perf_counter() - train_time output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type) fabric.print(output) # Final evaluation if eval.final_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics, step=state["iter_num"]) fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") # Save the final checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth" save_path.parent.mkdir(parents=True, exist_ok=True) fabric.save(save_path, {"model": state["model"]}) if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) save_hyperparameters(setup, save_path.parent) save_prompt_style(data.prompt_style, save_path.parent) def fit( fabric: L.Fabric, state: Dict, train_dataloader: DataLoader, val_dataloader: DataLoader, devices: int, resume: Union[bool, 
Literal["auto"], Path], checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, data: DataModule, num_nodes: int = 1, ) -> None: model = state["model"] optimizer = state["optimizer"] scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length( ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]) ) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) token_counts = { "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long), } if eval.initial_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) val_loss = f"{val_loss:.3f}" else: fabric.print("Verifying settings ...") validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check val_loss = "n/a" initial_iter = state["iter_num"] max_steps = train.max_steps or float("inf") train_iterator = CycleIterator(train_dataloader) # resume data loader state by fast-forwarding through all seen batches if resume: resume_t0 = time.perf_counter() for resume_iter in range(initial_iter): next(train_iterator) if resume_iter % 1000 == 0: fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}") fabric.barrier() fabric.print( f"Resuming data loader finished. Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration" f" {initial_iter}." 
) running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) fabric.barrier() while state["step_count"] < max_steps: state["iter_num"] += 1 iter_t0 = time.perf_counter() batch = next(train_iterator) if train_iterator.epoch >= train.epochs: break input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids) # shift the targets such that output n predicts token n+1 loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:]) fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) running_loss.update(loss.detach()) if not is_accumulating: optimizer.step() optimizer.zero_grad() scheduler.step() state["step_count"] += 1 token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item() token_counts["raw_tokens_plus_prompt_template"] += ( batch["token_counts"]["raw_plus_prompt_template"].sum().item() ) token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel() if state["iter_num"] % train.log_interval == 0: loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() metrics = { "loss": loss, "iter": state["iter_num"], "step": state["step_count"], "epoch": train_iterator.epoch, "iter_time": t1 - iter_t0, "tokens": token_counts["raw_tokens_plus_prompt_template"], "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size, "learning_rate": scheduler.get_last_lr()[0], } if isinstance(val_loss, torch.Tensor): val_loss = f"{val_loss:.3f}" fabric.print( f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" f" loss train: {metrics['loss']:.3f}," f" val: {val_loss} |" f" iter time: {metrics['iter_time'] * 1000:.2f} ms" f"{' (step)' if not is_accumulating else ''}" ) 
fabric.log_dict(metrics, step=state["iter_num"]) if not is_accumulating and state["step_count"] % eval.interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_dataloader, eval) generate_example(fabric, model, tokenizer, eval, data) t1 = time.perf_counter() - t0 val_loss_tensor = val_loss.detach().clone().to(fabric.device) val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32) fabric.all_reduce(val_loss_tensor, reduce_op="mean") fabric.all_reduce(val_time_tensor, reduce_op="mean") fabric.print( f"iter {state['iter_num']}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms" ) metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)} fabric.log_dict(metrics, step=state["iter_num"]) fabric.barrier() if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0: checkpoint_file = out_dir / f"step-{state['step_count']:06d}" / "lit_model.pth" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) fabric.print(f"Saving checkpoint to {str(checkpoint_file.parent)!r}") fabric.save(checkpoint_file, state) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) save_prompt_style(data.prompt_style, checkpoint_file.parent) total_token_counts = {} for key in token_counts: total = fabric.all_reduce(token_counts[key], reduce_op="sum") total_token_counts[key] = total.item() return total_token_counts # FSDP has issues with `inference_mode` @torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True ) -> torch.Tensor: if verbose: fabric.print("Validating ...") model.eval() losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) for k, batch in enumerate(val_dataloader): if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) 
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() model.train() return val_loss @torch.no_grad() def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): instruction = select_sft_generate_example(eval, data) fabric.print(instruction) prompt = data.prompt_style.apply(instruction) encoded = tokenizer.encode(prompt, device=fabric.device) model.eval() with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") else: print( f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " f"The model's supported context size (post-training) is {model.config.block_size}." 
) def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) return train_dataloader, val_dataloader def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) lengths = [len(d["input_ids"]) for d in data] longest_seq_length = max(lengths) longest_seq_ix = lengths.index(longest_seq_length) return longest_seq_length, longest_seq_ix def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if not train.epochs and not train.max_steps: issues.append(f"{__file__} requires either epochs or max_steps to be set. 
This is set in {train}") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/finetune/lora.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses import math import os import time import warnings from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.utilities import ThroughputMonitor from torch.utils.data import ConcatDataset, DataLoader from torchmetrics import RunningMean from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable from litgpt.parser_config import save_hyperparameters from litgpt.prompts import save_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer from litgpt.types import LoggerChoice from litgpt.utils import ( CycleIterator, auto_download_checkpoint, check_nvlink_connectivity, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, create_finetuning_performance_report, get_default_supported_precision, init_out_dir, instantiate_bnb_optimizer, instantiate_torch_optimizer, load_checkpoint, num_parameters, parse_devices, select_sft_generate_example, ) def setup( checkpoint_dir: Path, out_dir: Path = Path("out/finetune/lora"), precision: Optional[str] = None, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, devices: Union[int, str] = 1, num_nodes: int = 1, lora_r: int = 8, lora_alpha: int = 16, lora_dropout: float 
= 0.05, lora_query: bool = True, lora_key: bool = False, lora_value: bool = True, lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, global_batch_size=16, micro_batch_size=1, lr_warmup_steps=100, epochs=5, max_seq_length=None, max_time=None, ), log: LogArgs = LogArgs(), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), optimizer: Union[str, Dict] = "AdamW", logger_name: LoggerChoice = "csv", seed: int = 1337, access_token: Optional[str] = None, ) -> None: """Finetune a model using the LoRA method. Arguments: checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning. out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in /teamspace/jobs//share. precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. devices: How many devices/GPUs to use. num_nodes: How many nodes the code is being run on. lora_r: The LoRA rank. lora_alpha: The LoRA alpha. lora_dropout: The LoRA dropout value. lora_query: Whether to apply LoRA to the query weights in attention. lora_key: Whether to apply LoRA to the key weights in attention. lora_value: Whether to apply LoRA to the value weights in attention. lora_projection: Whether to apply LoRA to the output projection in the attention block. lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. lora_head: Whether to apply LoRA to output head in GPT. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. 
optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. access_token: Optional API token to access models with restrictions. """ checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) out_dir = init_out_dir(out_dir) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( checkpoint_dir / "model_config.yaml", lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_query=lora_query, lora_key=lora_key, lora_value=lora_value, lora_projection=lora_projection, lora_mlp=lora_mlp, lora_head=lora_head, ) precision = precision or get_default_supported_precision(training=True) logger = choose_logger( logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval, log_args=dataclasses.asdict(log), ) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None if devices * num_nodes > 1: if quantize: raise NotImplementedError( "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1" " when using the --quantize flag." 
) strategy = ModelParallelStrategy( parallelize_fn=parallelize_fn, data_parallel_size=devices * num_nodes, tensor_parallel_size=1, ) else: strategy = "auto" fabric = L.Fabric( devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger, plugins=plugins, ) if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) fabric.launch( main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision ) def main( fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, precision: Optional[str] = None, ) -> None: validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(fabric.world_size > 1)): model = GPT(config) mark_only_lora_as_trainable(model) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) from bitsandbytes.nn import StableEmbedding old_embedding = model.transformer.wte model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim) with torch.no_grad(): 
model.transformer.wte.weight.copy_(old_embedding.weight) model.transformer.wte = model.transformer.wte.to( device=old_embedding.weight.device, dtype=old_embedding.weight.dtype ) else: optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) load_checkpoint(fabric, model, checkpoint_path, strict=False) train_time = time.perf_counter() token_counts = fit( fabric=fabric, model=model, optimizer=optimizer, scheduler=scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, devices=devices, num_nodes=num_nodes, checkpoint_dir=checkpoint_dir, out_dir=out_dir, train=train, eval=eval, data=data, ) training_time = time.perf_counter() - train_time output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type) fabric.print(output) # Final evaluation if eval.final_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics) fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") # Save the final LoRA checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth.lora" save_path.parent.mkdir(parents=True, exist_ok=True) save_lora_checkpoint(fabric, model, save_path) fabric.barrier() if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) save_hyperparameters(setup, save_path.parent) save_prompt_style(data.prompt_style, save_path.parent) merge_lora( checkpoint_dir=save_path.parent, pretrained_checkpoint_dir=checkpoint_dir, precision=precision, ) fabric.barrier() def fit( fabric: L.Fabric, model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, 
train_dataloader: DataLoader, val_dataloader: DataLoader, devices: int, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, data: DataModule, num_nodes: int = 1, ) -> dict: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length( ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]) ) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) if eval.initial_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) val_loss = f"{val_loss:.3f}" else: fabric.print("Verifying settings ...") validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check val_loss = "n/a" train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) max_steps = train.max_steps or float("inf") step_count = 0 iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() max_time = train.max_time or float("inf") token_counts = { "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long), } while step_count < max_steps: iter_num += 1 iter_t0 = time.perf_counter() batch = next(train_iterator) if train_iterator.epoch >= train.epochs: generate_example(fabric, model, tokenizer, eval, data) fabric.print(f"Number of epochs {train.epochs} reached, stopping training...") break if iter_t0 - total_t0 > max_time: 
generate_example(fabric, model, tokenizer, eval, data) fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...") break input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids, lm_head_chunk_size=128) # shift the targets such that output n predicts token n+1 logits[-1] = logits[-1][..., :-1, :] loss = chunked_cross_entropy(logits, targets[..., 1:]) fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) running_loss.update(loss.detach()) if not is_accumulating: optimizer.step() optimizer.zero_grad() scheduler.step() step_count += 1 token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item() token_counts["raw_tokens_plus_prompt_template"] += ( batch["token_counts"]["raw_plus_prompt_template"].sum().item() ) token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel() total_lengths += input_ids.numel() if iter_num % train.log_interval == 0: loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths ) throughput.compute_and_log(step=iter_num) metrics = { "loss": loss, "iter": iter_num, "step": step_count, "epoch": train_iterator.epoch, "iter_time": t1 - iter_t0, "tokens": token_counts["raw_tokens_plus_prompt_template"], "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size, "learning_rate": scheduler.get_last_lr()[0], } if isinstance(val_loss, torch.Tensor): val_loss = f"{val_loss:.3f}" fabric.print( f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" f" loss train: {metrics['loss']:.3f}," f" val: {val_loss} |" f" iter time: {metrics['iter_time'] * 1000:.2f} ms" f"{' (step)' if not is_accumulating 
else ''}" ) fabric.log_dict(metrics, step=iter_num) if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_dataloader, eval) generate_example(fabric, model, tokenizer, eval, data) t1 = time.perf_counter() - t0 val_loss_tensor = val_loss.detach().clone().to(fabric.device) val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32) fabric.all_reduce(val_loss_tensor, reduce_op="mean") fabric.all_reduce(val_time_tensor, reduce_op="mean") fabric.print( f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms" ) metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)} fabric.log_dict(metrics, step=iter_num) fabric.barrier() if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) save_lora_checkpoint(fabric, model, checkpoint_file) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) save_prompt_style(data.prompt_style, checkpoint_file.parent) total_token_counts = {} for key in token_counts: total = fabric.all_reduce(token_counts[key], reduce_op="sum") total_token_counts[key] = total.item() return total_token_counts # FSDP has issues with `inference_mode` @torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True ) -> torch.Tensor: if verbose: fabric.print("Validating ...") model.eval() losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) for k, batch in enumerate(val_dataloader): if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) 
val_loss = losses.mean() model.train() return val_loss @torch.no_grad() def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): instruction = select_sft_generate_example(eval, data) fabric.print(instruction) prompt = data.prompt_style.apply(instruction) encoded = tokenizer.encode(prompt, device=fabric.device) model.eval() max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") else: print( f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " f"The model's supported context size (post-training) is {model.config.block_size}." 
) def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) return train_dataloader, val_dataloader def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) 
lengths = [len(d["input_ids"]) for d in data] longest_seq_length = max(lengths) longest_seq_ix = lengths.index(longest_seq_length) return longest_seq_length, longest_seq_ix def parallelize_fn(model, device_mesh, activation_checkpointing=True): from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper if activation_checkpointing: model.transformer.h = torch.nn.ModuleList( [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h] ) dp_mesh = device_mesh["data_parallel"] for m in reversed(list(model.modules())): if ( (isinstance(m, torch.nn.Linear) and m.weight.requires_grad) or isinstance(m, CheckpointWrapper) or isinstance(m, Block) ): fully_shard(m, mesh=dp_mesh) fully_shard(model, mesh=dp_mesh) return model def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: cpu_state_dict = {} sharded_sd = model.state_dict() for param_name, param in sharded_sd.items(): if "lora_" not in param_name: continue if param.is_cpu: param = param.to(fabric.device) if hasattr(param, "_local_tensor"): param = param.full_tensor() if fabric.is_global_zero: cpu_state_dict[param_name] = param.cpu() fabric.barrier() if fabric.is_global_zero: torch.save({"model": cpu_state_dict}, file_path) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. 
This is set in {args}") if not train.epochs and not train.max_steps: issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/finetune/lora_legacy.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses import math import os import time import warnings from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities import ThroughputMonitor from torch.utils.data import ConcatDataset, DataLoader from torchmetrics import RunningMean from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable from litgpt.parser_config import save_hyperparameters from litgpt.prompts import save_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer from litgpt.types import LoggerChoice from litgpt.utils import ( CycleIterator, auto_download_checkpoint, check_nvlink_connectivity, check_valid_checkpoint_dir, choose_logger, chunked_cross_entropy, copy_config_files, create_finetuning_performance_report, get_default_supported_precision, init_out_dir, instantiate_bnb_optimizer, instantiate_torch_optimizer, load_checkpoint, num_parameters, parse_devices, select_sft_generate_example, ) def setup( checkpoint_dir: Path, out_dir: Path = Path("out/finetune/lora"), precision: Optional[str] = None, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", 
"bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None, devices: Union[int, str] = 1, num_nodes: int = 1, lora_r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.05, lora_query: bool = True, lora_key: bool = False, lora_value: bool = True, lora_projection: bool = False, lora_mlp: bool = False, lora_head: bool = False, data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, global_batch_size=16, micro_batch_size=1, lr_warmup_steps=100, epochs=5, max_seq_length=None, ), log: LogArgs = LogArgs(), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), optimizer: Union[str, Dict] = "AdamW", logger_name: LoggerChoice = "csv", seed: int = 1337, access_token: Optional[str] = None, ) -> None: """Finetune a model using the LoRA method. Arguments: checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning. out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in /teamspace/jobs//share. precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. devices: How many devices/GPUs to use. num_nodes: How many nodes the code is being run on. lora_r: The LoRA rank. lora_alpha: The LoRA alpha. lora_dropout: The LoRA dropout value. lora_query: Whether to apply LoRA to the query weights in attention. lora_key: Whether to apply LoRA to the key weights in attention. lora_value: Whether to apply LoRA to the value weights in attention. lora_projection: Whether to apply LoRA to the output projection in the attention block. lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block. lora_head: Whether to apply LoRA to output head in GPT. data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. access_token: Optional API token to access models with restrictions. """ checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token) pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) out_dir = init_out_dir(out_dir) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( checkpoint_dir / "model_config.yaml", lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_query=lora_query, lora_key=lora_key, lora_value=lora_value, lora_projection=lora_projection, lora_mlp=lora_mlp, lora_head=lora_head, ) precision = precision or get_default_supported_precision(training=True) logger = choose_logger( logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval, log_args=dataclasses.asdict(log), ) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None if devices * num_nodes > 1: if quantize: raise NotImplementedError( "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1" " when using the --quantize flag." 
) strategy = FSDPStrategy( auto_wrap_policy={torch.nn.Linear}, activation_checkpointing_policy={Block}, state_dict_type="full", limit_all_gathers=True, cpu_offload=False, ) else: strategy = "auto" fabric = L.Fabric( devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger, plugins=plugins, ) if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes) def main( fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, ) -> None: validate_args(train, eval) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(fabric.world_size > 1)): model = GPT(config) mark_only_lora_as_trainable(model) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") model = fabric.setup_module(model) if isinstance(fabric.strategy.precision, BitsandbytesPrecision): optimizer = instantiate_bnb_optimizer(optimizer, model.parameters()) from bitsandbytes.nn import StableEmbedding old_embedding = model.transformer.wte model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim) with torch.no_grad(): 
model.transformer.wte.weight.copy_(old_embedding.weight) model.transformer.wte = model.transformer.wte.to( device=old_embedding.weight.device, dtype=old_embedding.weight.dtype ) else: optimizer = instantiate_torch_optimizer(optimizer, model.parameters()) optimizer = fabric.setup_optimizers(optimizer) scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) # strict=False because missing keys due to LoRA weights not contained in state dict load_checkpoint(fabric, model, checkpoint_path, strict=False) train_time = time.perf_counter() token_counts = fit( fabric=fabric, model=model, optimizer=optimizer, scheduler=scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, devices=devices, num_nodes=num_nodes, checkpoint_dir=checkpoint_dir, out_dir=out_dir, train=train, eval=eval, data=data, ) training_time = time.perf_counter() - train_time output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type) fabric.print(output) # Final evaluation if eval.final_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics) fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") # Save the final LoRA checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth.lora" save_path.parent.mkdir(parents=True, exist_ok=True) save_lora_checkpoint(fabric, model, save_path) if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) save_hyperparameters(setup, save_path.parent) save_prompt_style(data.prompt_style, save_path.parent) merge_lora(checkpoint_dir=save_path.parent) def fit( fabric: L.Fabric, model: GPT, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, train_dataloader: DataLoader, 
val_dataloader: DataLoader, devices: int, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs, data: DataModule, num_nodes: int = 1, ) -> dict: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length( ConcatDataset([train_dataloader.dataset, val_dataloader.dataset]) ) model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" ) if eval.initial_validation: val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader))) val_loss = f"{val_loss:.3f}" else: fabric.print("Verifying settings ...") validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check val_loss = "n/a" train_iterator = CycleIterator(train_dataloader) throughput = ThroughputMonitor(fabric, window_size=50) running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) max_steps = train.max_steps or float("inf") step_count = 0 iter_num = 0 total_lengths = 0 total_t0 = time.perf_counter() token_counts = { "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long), "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long), } while step_count < max_steps: iter_num += 1 iter_t0 = time.perf_counter() batch = next(train_iterator) if train_iterator.epoch >= train.epochs: break input_ids, targets = batch["input_ids"], batch["labels"] is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids, 
lm_head_chunk_size=128) # shift the targets such that output n predicts token n+1 logits[-1] = logits[-1][..., :-1, :] loss = chunked_cross_entropy(logits, targets[..., 1:]) fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) running_loss.update(loss.detach()) if not is_accumulating: optimizer.step() optimizer.zero_grad() scheduler.step() step_count += 1 token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item() token_counts["raw_tokens_plus_prompt_template"] += ( batch["token_counts"]["raw_plus_prompt_template"].sum().item() ) token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel() total_lengths += input_ids.numel() if iter_num % train.log_interval == 0: loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths ) throughput.compute_and_log(step=iter_num) metrics = { "loss": loss, "iter": iter_num, "step": step_count, "epoch": train_iterator.epoch, "iter_time": t1 - iter_t0, "tokens": token_counts["raw_tokens_plus_prompt_template"], "total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size, "learning_rate": scheduler.get_last_lr()[0], } if isinstance(val_loss, torch.Tensor): val_loss = f"{val_loss:.3f}" fabric.print( f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" f" loss train: {metrics['loss']:.3f}," f" val: {val_loss} |" f" iter time: {metrics['iter_time'] * 1000:.2f} ms" f"{' (step)' if not is_accumulating else ''}" ) fabric.log_dict(metrics, step=iter_num) if not is_accumulating and step_count % eval.interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_dataloader, eval) generate_example(fabric, model, tokenizer, eval, data) t1 = time.perf_counter() - t0 val_loss_tensor = val_loss.detach().clone().to(fabric.device) val_time_tensor = 
torch.tensor(t1, device=fabric.device, dtype=torch.float32) fabric.all_reduce(val_loss_tensor, reduce_op="mean") fabric.all_reduce(val_time_tensor, reduce_op="mean") fabric.print( f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms" ) metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)} fabric.log_dict(metrics, step=iter_num) fabric.barrier() if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) save_lora_checkpoint(fabric, model, checkpoint_file) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) save_prompt_style(data.prompt_style, checkpoint_file.parent) total_token_counts = {} for key in token_counts: total = fabric.all_reduce(token_counts[key], reduce_op="sum") total_token_counts[key] = total.item() return total_token_counts # FSDP has issues with `inference_mode` @torch.no_grad() def validate( fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True ) -> torch.Tensor: if verbose: fabric.print("Validating ...") model.eval() losses = torch.zeros(min(len(val_dataloader), eval.max_iters)) for k, batch in enumerate(val_dataloader): if k >= eval.max_iters: break input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0) val_loss = losses.mean() model.train() return val_loss @torch.no_grad() def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule): instruction = select_sft_generate_example(eval, data) fabric.print(instruction) prompt = data.prompt_style.apply(instruction) encoded = tokenizer.encode(prompt, device=fabric.device) 
model.eval() max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id ) model.clear_kv_cache() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") else: print( f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) " f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. " f"The model's supported context size (post-training) is {model.config.block_size}." ) def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): # linear warmup followed by cosine annealing scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps)) return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps]) def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs ) -> Tuple[DataLoader, DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) return train_dataloader, val_dataloader def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: # find out the minimum max_seq_length required during fine-tuning (saves memory!) 
lengths = [len(d["input_ids"]) for d in data] longest_seq_length = max(lengths) longest_seq_ix = lengths.index(longest_seq_length) return longest_seq_length, longest_seq_ix def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: fabric.print(f"Saving LoRA weights to {str(file_path)!r}") fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") required = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if not train.epochs and not train.max_steps: issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/generate/__init__.py ================================================ ================================================ FILE: litgpt/generate/adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import sys import time import warnings from pathlib import Path from pprint import pprint from typing import Literal, Optional import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from litgpt import PromptStyle, Tokenizer from litgpt.adapter import GPT, Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.generate.base import generate from litgpt.prompts import has_prompt_style, load_prompt_style from litgpt.utils import ( check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, lazy_load, ) def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", input: str = "", sys_prompt: Optional[str] = None, adapter_path: Path = Path("out/finetune/adapter/final/lit_model.pth.adapter"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, precision: Optional[str] = None, ) -> None: """For models finetuned with `litgpt finetune_adapter`. Generates a response based on a given instruction and an optional input. This script will only work with checkpoints from the instruction-tuned adapter model. See ``litgpt.finetune.adapter``. Args: checkpoint_dir: The path to the checkpoint folder with pretrained model weights. prompt: The prompt/instruction (Alpaca style). input: Optional input (Alpaca style). sys_prompt: Optional system prompt. adapter_path: Path to the checkpoint with trained adapter weights, which are the output of ``litgpt.finetune.adapter``. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. 
top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. precision: Indicates the Fabric precision setting to use. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." 
) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() t0 = time.perf_counter() checkpoint = lazy_load(checkpoint_path) adapter_checkpoint = lazy_load(adapter_path) checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) model.load_state_dict(checkpoint) fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) model = fabric.setup(model) L.seed_everything(1234) t0 = time.perf_counter() y = generate( model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id ) t = time.perf_counter() - t0 output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() 
fabric.print(output) tokens_generated = y.size(0) - prompt_length fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/adapter_v2.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time import warnings from pathlib import Path from pprint import pprint from typing import Literal, Optional import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from litgpt import PromptStyle, Tokenizer from litgpt.adapter_v2 import GPT, Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.generate.base import generate from litgpt.prompts import has_prompt_style, load_prompt_style from litgpt.utils import ( check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, lazy_load, ) def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", input: str = "", sys_prompt: Optional[str] = None, adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, precision: Optional[str] = None, ) -> None: """For models finetuned with `litgpt finetune adapter_v2`. Generates a response based on a given instruction and an optional input. This script will only work with checkpoints from the instruction-tuned adapter v2 model. See ``litgpt.finetune.adapter_v2``. Args: checkpoint_dir: The path to the checkpoint folder with pretrained model weights. 
prompt: The prompt/instruction (Alpaca style). input: Optional input (Alpaca style). sys_prompt: Optional system prompt. adapter_path: Path to the checkpoint with trained adapter weights, which are the output of ``litgpt.finetune.adapter_v2``. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. precision: Indicates the Fabric precision setting to use. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. 
This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() t0 = time.perf_counter() checkpoint = lazy_load(checkpoint_path) adapter_checkpoint = lazy_load(adapter_path) checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) model.load_state_dict(checkpoint) fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) model = fabric.setup(model) L.seed_everything(1234) t0 = time.perf_counter() y = generate( model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id ) t = time.perf_counter() - t0 output = tokenizer.decode(y) output = 
output.split("### Response:")[1].strip() fabric.print(output) tokens_generated = y.size(0) - prompt_length fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/base.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time import warnings from pathlib import Path from pprint import pprint from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union import lightning as L import torch import torch._dynamo.config import torch._inductor.config from lightning.fabric.plugins import BitsandbytesPrecision from litgpt.config import Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.model import GPT from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, ) def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: if torch._dynamo.is_compiling(): # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly distribution = torch.empty_like(probs).exponential_(1) return torch.argmax(probs / distribution, dim=-1, keepdim=True) return torch.multinomial(probs, num_samples=1) def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Example: # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0] # sorted_indices_to_remove = [1, 1, 
0, 0, 0] if top_p=0.7 sorted_indices_to_remove = cumulative_probs <= (1 - top_p) # Keep at least 1 token always to prevent the case where no token is selected # In this case the most probable one is always kept sorted_indices_to_remove[-1:] = 0 indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) logits = logits.masked_fill(indices_to_remove, float("-inf")) return logits def sample( logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0 ) -> torch.Tensor: if top_p < 0.0 or top_p > 1.0: raise ValueError(f"top_p must be in [0, 1], got {top_p}") logits = logits[0, -1] # optionally crop the logits to only the top k options if top_k is not None: v, i = torch.topk(logits, min(top_k, logits.size(-1))) # do not use `torch.where` as in nanogpt because it will repeat top-k collisions logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v) # optionally scale the logits and sample from a probability distribution if temperature > 0.0 or top_p > 0.0: if temperature > 0.0: logits = logits / temperature # optionally crop the logits to smallest set of logits with a cumulative probability above top_p if top_p < 1.0: logits = sample_top_p(logits, top_p) probs = torch.nn.functional.softmax(logits, dim=-1) return multinomial_num_samples_1(probs) return torch.argmax(logits, dim=-1, keepdim=True) def next_token( model: GPT, input_pos: torch.Tensor, x: torch.Tensor, input_pos_maxp1: Optional[int] = None, **sample_kwargs: Dict[str, Any], ) -> torch.Tensor: logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1) _next = sample(logits, **sample_kwargs).to(dtype=torch.int64) return _next def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor: assert len(logits) == len(kwargs), "logits and kwargs must have the same length." 
return torch.stack( [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0 ) def batched_next_token( model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]] ) -> torch.Tensor: # Where: # input_pos is a 1d tensor of shape [seq_length...] # x is context tokens to add to the kvcache. # For prefill, x is a 2d tensor of shape [batch_size, prompt_length]. # For subsequent tokens, x is a 2d tensor of shape [batch_size, 1]. # kwargs is a list of dictionaries, each containing the keyword arguments for the sample function. # If one dictionary is passed, it's repeated for each sample in the batch. # In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length]. # That way, we can support prompts of different sizes. # This means making the rope cache and kvcache forward() work with batches. Currently, they do not. # This is relatively complicated, given the current implementation. It will require some rewriting. # Relevant thread: https://discuss.pytorch.org/t/batched-index-select/9115 # We will also need the same with tensor.index_copy_(). These do not work for batches, and the replacement # is somewhat nontrivial. Until then, we can only accept prompts that are all the same length. # After this problem is resolved, there will be another problem. That being, continuous batched prefill. # If you have any ideas on this, let me know. I don't think that padding input_pos is viable. _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0) # Run the model on the batch. logits_stack = model(x, input_pos) # Unbind the logits stack into a list of logits. logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0) logits_list = [l.unsqueeze(0) for l in logits_list] # Return the next token for each sample in the batch. 
return batched_sample(logits_list, kwargs=_kwargs) @torch.inference_mode() def generate_fn( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, *, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, stop_tokens: Tuple[List[int], ...] = (), include_prompt: bool, include_eos: bool, ) -> Iterator[torch.Tensor]: """ Generates tokens for a single prompt. Args: model: The model to use. prompt: The tokenized prompt to generate from. max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. temperature: The temp to pass to sample(). top_k: The top_k to pass to sample(). top_p: The top_p to pass to sample(). stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. include_eos: Whether to output the stop tokens if generation stops early. """ prompt_size = prompt.size(0) device = prompt.device assert max_returned_tokens > prompt_size, ( f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}." ) if model.max_seq_length < max_returned_tokens - 1: raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") # Yield the prompt if include_prompt is True if include_prompt: yield prompt stop_progress = [0] * len(stop_tokens) yielded_idx = 0 # Generate output tokens. # The first token generated is the prefill token. # The input_pos for this token is the width of the entire prompt. # For subsequent iterations, it's the index in the context for the token that we're generating. tokens = [] token = prompt prefill_token = True input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) # input_pos_maxp1 introduces data-dependent shapes and control flow. # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc. 
input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None for current_idx in range(max_returned_tokens - prompt_size): # Generate the token token = next_token( model, input_pos, token.view(1, -1), input_pos_maxp1=input_pos_maxp1, temperature=temperature, top_k=top_k, top_p=top_p, ) tokens.append(token) int_token = token.item() # Check for stop sequences # For each stop sequence, we keep a running total of how many are matched in stop_progress. # If the current token matches the next token in the stop sequence, we increment the # running total and hold off on yielding the token. for i, seq in enumerate(stop_tokens): if int_token == seq[stop_progress[i]]: stop_progress[i] += 1 if stop_progress[i] == len(seq): if include_eos: yield from tokens[yielded_idx:] return else: stop_progress[i] = 0 # Yield tokens that are not part of a stop sequence in progress. # If there are no stop sequences, then that's all of them. if stop_tokens: safe_idx = len(tokens) - max(stop_progress) else: safe_idx = current_idx + 1 # include the token just generated if yielded_idx < safe_idx: y_tokens = tokens[yielded_idx:safe_idx] yield from y_tokens yielded_idx = safe_idx # Update input_pos for the next iteration. if prefill_token: prefill_token = False input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) else: input_pos.add_(1) if input_pos_maxp1 is not None: input_pos_maxp1 += 1 # Yield any remaining tokens if yielded_idx < len(tokens): yield from tokens[yielded_idx:] # TODO: Make include_eos work. # TODO: Rewrite unbatched generate_fn to use batched_generate_fn. @torch.inference_mode() def batched_generate_fn( model: GPT, prompts: torch.Tensor, max_returned_tokens: int, *, sample_args: Union[list[dict], dict], stop_tokens: Tuple[List[int], ...] = (), include_prompt: bool, include_eos: bool, ) -> Iterator[list[Union[torch.Tensor, None]]]: """ Generates tokens for a batch of prompts. Args: model: The model to use. 
prompts: A 2D tensor of shape [batch_size, prompt_length]. max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. include_eos: Whether to output the stop tokens if generation stops early. Yields: A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch. """ if prompts.ndim == 1: prompts = prompts.unsqueeze(0) assert prompts.ndim == 2, "Prompts must be a 2D tensor." batch_size = prompts.size(0) max_prompt_size = prompts.size(1) device = prompts.device if isinstance(sample_args, dict): sample_args = [sample_args] * len(prompts) else: assert len(sample_args) == batch_size, "sample_args must have the length as the batch size." # TODO: This check (and the one in generate_fn) is not sufficient. We do the proper checks in LLM.generate(). assert max_returned_tokens > max_prompt_size, ( f"Not enough space for {max_prompt_size} prompt tokens in a context length of {max_returned_tokens}." ) if model.max_seq_length < max_returned_tokens - 1: raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}") # Yield the prompts if include_prompt is True if include_prompt: # TODO: Prompt length is padded, but they shouldn't all be the same length. for i in range(max_prompt_size): yield [prompt[i].view(-1) for prompt in prompts] stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)] stop_idxes = [-1] * batch_size yielded_idx = 0 # Generate output tokens. # The first token generated is the prefill token. # The input_pos for this token is the width of the entire prompt. 
# For subsequent iterations, it's the index in the context for the token that we're generating. token_lists = [[] for _ in range(batch_size)] tokens: torch.Tensor = prompts prefill_token = True input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64) for current_idx in range(max_returned_tokens - max_prompt_size): # Generate the next token for each prompt in the batch. # This is of shape [batch_size, 1]. tokens = batched_next_token(model, input_pos, tokens, sample_args) for i in range(batch_size): token_lists[i].append(tokens[i]) int_tokens = [token.item() for token in tokens] # Check for stop sequences # For each stop sequence, we keep a running total of how many are matched in stop_progress. # If the current token matches the next token in the stop sequence, we increment the # running total and hold off on yielding the token. for batch_idx, int_token in enumerate(int_tokens): if stop_idxes[batch_idx] != -1: continue for seq_idx, seq in enumerate(stop_tokens): seq_pos = stop_progresses[batch_idx][seq_idx] if seq_pos >= len(seq): continue if int_token == seq[seq_pos]: stop_progresses[batch_idx][seq_idx] += 1 if stop_progresses[batch_idx][seq_idx] == len(seq): stop_idxes[batch_idx] = current_idx else: stop_progresses[batch_idx][seq_idx] = 0 # Yield tokens that are not part of a stop sequence in progress. # If there are no stop sequences, then that's all of them. if len(stop_tokens) != 0: safe_idxes = [len(token_lists[i]) - max(stop_progresses[i]) for i in range(batch_size)] else: safe_idxes = [current_idx + 1] # include the token just generated safe_idx = min(safe_idxes) if yielded_idx < safe_idx: for idx in range(yielded_idx, safe_idx): y_tokens = [ token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size) ] if all(y is None for y in y_tokens): return yield y_tokens yielded_idx = safe_idx # Update input_pos for the next iteration. 
if prefill_token: prefill_token = False # TODO: Make the model support a batched input_pos of shape [batch_size, 1]. # The kvcache has been fixed, but the rope cache is still broken. input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64) else: input_pos.add_(1) # Yield any remaining tokens max_token_lists = max(len(l) for l in token_lists) if yielded_idx < max_token_lists: for idx in range(yielded_idx, max_token_lists): y_tokens = [ token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size) ] if all(y is None for y in y_tokens): return yield y_tokens return @torch.inference_mode() def generate( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, *, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, eos_id: Optional[int] = None, include_prompt: bool = True, ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. The implementation of this function is modified from A. Karpathy's nanoGPT. Args: model: The model to use. prompt: Tensor of shape (T) with indices of the prompt sequence. max_returned_tokens: The maximum number of tokens to return (given plus generated). temperature: Scales the predicted logits by 1 / temperature. top_k: If specified, only sample among the tokens with the k highest probabilities. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. 
`top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p eos_id: If specified, stop generating any more token once the token is triggered. include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output. """ token_list = list( generate_fn( include_prompt=include_prompt, include_eos=True, model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=(([eos_id],) if eos_id is not None else ()), ) ) return torch.cat(token_list) if not len(token_list) == 0 else torch.Tensor() @torch.inference_mode() def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", *, sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, precision: Optional[str] = None, compile: bool = False, ) -> None: """Default generation option. Generates text samples based on a pre-trained model and tokenizer. Args: checkpoint_dir: The checkpoint directory to load. prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. 
It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. compile: Whether to compile the model. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." 
) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() if compile: torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True global next_token next_token = torch.compile(next_token, mode="reduce-overhead") model = fabric.setup_module(model) t0 = time.perf_counter() load_checkpoint(fabric, model, checkpoint_path) fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate( model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, ) t 
= time.perf_counter() - t0 for block in model.transformer.h: block.attn.kv_cache.reset_parameters() fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/full.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time import warnings from pathlib import Path from pprint import pprint from typing import Literal, Optional import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision from litgpt import GPT, Config, PromptStyle, Tokenizer from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.generate.base import generate from litgpt.prompts import has_prompt_style, load_prompt_style from litgpt.utils import ( check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, ) def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", input: str = "", sys_prompt: Optional[str] = None, finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, precision: Optional[str] = None, ) -> None: """For models finetuned with `litgpt finetune_full`. Generates a response based on a given instruction and an optional input. This script will only work with checkpoints from the instruction-tuned model. See ``litgpt.finetune.full``. 
Args: checkpoint_dir: The path to the checkpoint folder with pretrained model weights. prompt: The prompt/instruction (Alpaca style). input: Optional input (Alpaca style). sys_prompt: Optional system prompt. finetuned_path: Path to the checkpoint with trained weights, which are the output of ``litgpt.finetune.full``. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. precision: Indicates the Fabric precision setting to use. 
""" checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = finetuned_path check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() model = fabric.setup(model) t0 = time.perf_counter() load_checkpoint(fabric, model, checkpoint_path) fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", 
file=sys.stderr) L.seed_everything(1234) t0 = time.perf_counter() y = generate( model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id ) t = time.perf_counter() - t0 output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() fabric.print(output) tokens_generated = y.size(0) - prompt_length fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/sequentially.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import itertools import logging import re import sys import time import warnings from collections import OrderedDict from functools import partial from pathlib import Path from pprint import pprint from typing import List, Literal, Optional, Type import lightning as L import torch from lightning.fabric.accelerators import CUDAAccelerator from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from tqdm import tqdm import litgpt.generate.base as generate_base from litgpt.config import Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.model import GPT, Block, build_mask_cache from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, ) @torch.inference_mode() def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int): if model.config.n_layer < devices: raise ValueError( f"The number of layers in the model must be larger than the number of devices, but 
got" f" n_layer={model.config.n_layer} and devices={devices}." ) # Dictates where each block should be instantiated mapping = layer_to_device( model, chunk_on=Block, chunk_sizes=chunk_sizes(model.config.n_layer, devices), ) num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)} # materialize each block on the appropriate device with tqdm(total=len(mapping), desc="Moving submodules") as pbar: for path, target_index in mapping.items(): submodule = model.get_submodule(path) target_device = torch.device(root.type, target_index) pbar.set_description(f"Moving {path!r} to {target_device}") pbar.update(1) # submodules loaded by the checkpoint will be on CPU (if no quantization). move them replace_device(submodule, replace=torch.device("cpu"), by=target_device) # in case the checkpoint was partial, materialize leftover metas _materialize_meta_tensors(submodule, target_device) # and build the kv cache submodule.attn.kv_cache = submodule.attn.build_kv_cache( 1, max_seq_length, model.rope_cache_length(), target_device ) # rebuild odd ends with root: model.max_seq_length = max_seq_length # the rope cache which is on meta device model.cos, model.sin = model.rope_cache() # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers model.mask_cache = build_mask_cache(max_seq_length) # and everything that is not a block in the root _materialize_meta_tensors(model, root) replace_device(model, replace=torch.device("cpu"), by=root) if devices > 1: # install hooks to move layer inputs/output between devices for layer_num, (path, target_index) in enumerate(mapping.items()): submodule = model.get_submodule(path) if layer_num >= num_layers_per_device[target_index]: # we need to move the block input on the boundaries between devices # and also on every non-root device because the RoPE and mask cache is shared # TODO: the second case could be optimized and then we would only need this hook for # `layer_num in 
[layers_per_rank * i - 1 for i in range(1, devices + 1)]` target_device = torch.device(root.type, target_index) submodule.register_forward_pre_hook(partial(move_block_input, target_device)) if layer_num == model.config.n_layer - 1: submodule.register_forward_hook(partial(move_block_output, root)) return model def chunk_sizes(num_units: int, devices: int) -> List[int]: cs = num_units // devices k = devices * (cs + 1) - num_units return [cs] * k + [cs + 1] * (devices - k) def layer_to_device( module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_sizes: List[int], ) -> "OrderedDict[str, int]": """Create a mapping from layer (block) to device.""" # this assumes that the definition order is the same as the execution order hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)] if sum(chunk_sizes) != len(hits): raise ValueError(f"Found {len(hits)} for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}") _devices = [[d] * cs for d, cs in enumerate(chunk_sizes)] devices = [d for lst in _devices for d in lst] return OrderedDict(zip(hits, devices)) def move_block_input(device: torch.device, module: torch.nn.Module, ins): """``forward_pre_hook`` to move a Block's input before forward.""" # during inference, none of the inputs are None: x, cos, sin, mask, input_pos return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins) def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor: """``forward_hook`` to move a Block's output after forward.""" return outs.to(device) def replace_device(module: torch.nn.Module, replace: torch.device, by: torch.device) -> torch.nn.Module: for name, submodule in module.named_modules(): tensors = dict( itertools.chain(submodule.named_parameters(recurse=False), submodule.named_buffers(recurse=False)) ) if not tensors: continue devices = {t.device for t in tensors.values()} if len(devices) != 1: # since this is using `submodule.to`, different 
devices in the same submodule is a problem path_to_device = {f"{name}.{p}": t.device for p, t in tensors.items()} raise ValueError(f"Found multiple devices: {path_to_device}") if devices.pop() == replace: submodule.to(by) return module @torch.inference_mode() def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", *, sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None, precision: Optional[str] = None, compile: bool = False, ) -> None: """Generation script that partitions layers across devices to be run sequentially. Generates text samples based on a pre-trained model and tokenizer. Args: checkpoint_dir: The checkpoint directory to load. prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. 
quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. compile: Whether to compile the model. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None: if compile: raise NotImplementedError # untested if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." ) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes").setLevel(logging.DEBUG) plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, accelerator="cuda", plugins=plugins) total_devices = CUDAAccelerator.auto_device_count() print(f"Using {total_devices} devices", file=sys.stderr) check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() # cannot use `init_module` because if bitsandbytes is used, the Linear 
layers will be replaced # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) t0 = time.perf_counter() state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") # TODO: this assumes that the model fits on CPU. Use lazy_load and make the materialization checkpoint aware model.load_state_dict(state_dict, assign=True) print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) model = fabric.setup_module(model, move_to_device=False) t0 = time.perf_counter() model = sequential(model, fabric.device, max_returned_tokens, total_devices) print(f"Time to sequential-ize the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) if compile: # TODO: raises an internal compile AssertionError caused by fabric.strategy.precision.forward_context raise NotImplementedError # silence developer warning on nightly builds # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/ir.py#L4166 pattern = re.compile(".*DeviceCopy in input program.*") logging.getLogger("torch._inductor.utils").addFilter(lambda record: not pattern.search(record.getMessage())) torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True # cannot use cudagraphs because it doesn't support multiple device indices # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375 generate_base.next_token = torch.compile(generate_base.next_token) L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( model=model, prompt=encoded, max_returned_tokens=max_returned_tokens, temperature=temperature, 
top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 for block in model.transformer.h: block.attn.kv_cache.reset_parameters() print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/speculative_decoding.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import sys import time import warnings from pathlib import Path from pprint import pprint from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple import lightning as L import torch import torch._dynamo.config import torch._inductor.config import torch.nn.functional as F from lightning.fabric.plugins import BitsandbytesPrecision from litgpt.config import Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.generate.base import multinomial_num_samples_1, next_token, sample_top_p from litgpt.model import GPT from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_file_size_on_cpu_and_warn, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, ) def sample( logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, apply_softmax: bool = True, ) -> torch.Tensor: if top_p < 0.0 or top_p > 1.0: raise ValueError(f"top_p must be in [0, 1], got {top_p}") logits = logits[0, -1] # optionally crop the logits to only the top k options if top_k is not None: v, i = torch.topk(logits, min(top_k, logits.size(-1))) # do not use `torch.where` as in nanogpt because it will repeat top-k collisions fill_value = 
float("-inf") if apply_softmax else float(0) logits = torch.full_like(logits, fill_value).scatter_(-1, i, v) # optionally scale the logits and sample from a probability distribution if temperature > 0.0 or top_p > 0.0: if temperature > 0.0: logits = logits / temperature # optionally crop the logits to smallest set of logits with a cumulative probability above top_p if top_p < 1.0: logits = sample_top_p(logits, top_p) probs = F.softmax(logits, dim=-1) if apply_softmax else logits return multinomial_num_samples_1(probs), probs return torch.argmax(logits, dim=-1, keepdim=True), F.softmax(logits, dim=-1) def speculative_decoding( draft_model: GPT, target_model: GPT, token: torch.Tensor, input_pos: torch.Tensor, input_pos_maxp1: int, speculative_k: int, **sample_kwargs: Dict[str, Any], ) -> torch.Tensor: """Performs speculative decoding using a draft and a target model. This implements the speculative decoding algorithm from "Fast Inference from Transformers via Speculative Decoding" (https://arxiv.org/pdf/2211.17192). The core idea is to: 1. Use a faster draft model to predict multiple tokens ahead 2. Verify those predictions with the slower but more accurate target model 3. Accept tokens where the target model agrees with high probability 4. Reject and resample tokens where there is a disagreement This allows leveraging a smaller/faster model to speed up generation while maintaining the quality of the larger target model. 
Args: draft_model: Smaller/faster model used for initial token predictions target_model: Larger/slower model used for verification token: Current input token tensor of shape [1] input_pos: Position index of the token tensor for KV-cache input_pos_maxp1: Maximum position + 1 for managing KV-cache buffer speculative_k: Number of tokens to speculatively generate at once sample_kwargs: Additional sampling parameters (temperature, top_k, top_p) Returns: torch.Tensor: Generated tokens that were either accepted from draft model or resampled from target model """ if speculative_k < 1: raise ValueError(f"speculative_k must be >= 1, got {speculative_k}") # Step 1: Generate candidate tokens using draft model # The draft model autoregressively generates k tokens, keeping track of probabilities draft_input_pos = input_pos.clone() draft_input_pos_maxp1 = input_pos_maxp1 draft_tokens, draft_probs = [], [] draft_token = token for idx in range(speculative_k): logits = draft_model( idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1 ) draft_token, draft_prob = sample(logits, **sample_kwargs) draft_input_pos.add_(1) draft_input_pos_maxp1 += 1 draft_tokens.append(draft_token) draft_probs.append(draft_prob) draft_tokens = torch.cat(draft_tokens) # Step 2: Get target model predictions for comparison # Feed both original token and draft tokens to get target probabilities candidate_tokens = torch.cat((token, draft_tokens)) candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device) candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k target_logits = target_model( idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1 ) # Step 3: Convert target logits to probabilities using same sampling params target_probs = [] for target_logit in target_logits.split(1, dim=1): _, target_prob = sample(target_logit, **sample_kwargs) target_probs.append(target_prob) # Step 
4: Accept/reject draft tokens based on probability comparison # Using rejection sampling: keep token if target_prob >= draft_prob. # Otherwise reject with probability 1 - target_prob / draft_prob. # If rejected, sample from an adjusted distribution: norm(max(0, target_prob_distribution - draft_prob_distribution) instead. accepted_tokens = [] for idx in range(len(draft_tokens)): draft_token = draft_tokens[idx].unsqueeze(0) draft_prob = draft_probs[idx][draft_token] target_prob = target_probs[idx][draft_token] # Accept the draft token if the target model is "confident" in it if target_prob >= draft_prob: accepted_tokens.append(draft_token) continue # If not accepted, probabilistically reject it discard_prob = 1 - target_prob / draft_prob should_discard_token = torch.rand(1, device=discard_prob.device) <= discard_prob if not should_discard_token: accepted_tokens.append(draft_token) continue # On rejection: sample new token from adjusted distribution # p'(x) = normalize(max(0, p_target(x) - p_draft(x))) adjusted_distribution = target_probs[idx] - draft_probs[idx] adjusted_distribution = torch.clamp(adjusted_distribution, 0.0) adjusted_distribution = adjusted_distribution / adjusted_distribution.sum() new_token, _ = sample(adjusted_distribution[None, None, ...], apply_softmax=False, **sample_kwargs) return torch.cat((*accepted_tokens, new_token)) # If all draft tokens were accepted: # 1. Update draft model's key-value cache # 2. Sample one more token from target model draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1) new_token, _ = sample(target_logits, **sample_kwargs) return torch.cat((*accepted_tokens, new_token)) @torch.inference_mode() def generate( draft_model: GPT, target_model: GPT, prompt: torch.Tensor, max_returned_tokens: int, *, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, stop_tokens: Tuple[List[int], ...] 
= (), include_prompt: bool = True, speculative_k: int, ) -> Iterator[torch.Tensor]: """Generates tokens using speculative decoding with a draft and a target model. This function implements token generation using speculative decoding, where a faster draft model makes initial token predictions that are verified by a slower but more accurate target model. Args: draft_model: Smaller/faster model used for initial token predictions target_model: Larger/more accurate model used to verify draft predictions prompt: Input tensor of token ids to generate from, shape [sequence_length] max_returned_tokens: Maximum total tokens (prompt + generated) to return temperature: Sampling temperature (higher = more random, lower = more deterministic) top_k: If set, only sample from the top k most likely next tokens top_p: If <1.0, only sample from tokens whose cumulative probability exceeds top_p stop_tokens: List of token sequences that will stop generation if produced include_prompt: Whether to include prompt tokens in the returned sequence speculative_k: Number of tokens to speculatively generate at each step Returns: - tokens: Tensor of generated token ids - acceptance_rate: Ratio of accepted draft model predictions This implements an optimized decoding process: 1. Both models process the initial prompt 2. Draft model speculatively generates k tokens ahead 3. Target model verifies the draft predictions 4. Accepted tokens are kept, rejected ones trigger resampling 5. Process repeats until max tokens or stop sequence reached """ prompt_size = prompt.size(0) device = prompt.device assert max_returned_tokens > prompt_size, ( f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}." 
) if draft_model.max_seq_length < max_returned_tokens - 1: raise NotImplementedError( f"max_seq_length {draft_model.max_seq_length} needs to be >= {max_returned_tokens - 1}" ) if target_model.max_seq_length < max_returned_tokens - 1: raise NotImplementedError( f"max_seq_length {target_model.max_seq_length} needs to be >= {max_returned_tokens - 1}" ) # Step 1: Prefill draft and target models with the prompt. input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc. input_pos_maxp1 = ( prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None ) next_token( draft_model, input_pos, prompt.view(1, -1), input_pos_maxp1=input_pos_maxp1, temperature=temperature, top_k=top_k, top_p=top_p, ) token = next_token( target_model, input_pos, prompt.view(1, -1), input_pos_maxp1=input_pos_maxp1, temperature=temperature, top_k=top_k, top_p=top_p, ) # Update position trackers after prompt input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) input_pos_maxp1 += 1 # Step 2: Main generation loop. 
tokens = [] total_generated, total_accepted = 0, 0 # Track acceptance statistics while input_pos < max_returned_tokens - 1: # Calculate speculative tokens to generate _speculative_k = min(speculative_k, (max_returned_tokens - input_pos - 1).item()) # Get new tokens via speculative decoding new_tokens = speculative_decoding( draft_model=draft_model, target_model=target_model, token=token, input_pos=input_pos, input_pos_maxp1=input_pos_maxp1, speculative_k=_speculative_k, temperature=temperature, top_k=top_k, top_p=top_p, ) # Update statistics accepted_tokens_len = len(new_tokens) total_generated += _speculative_k total_accepted += accepted_tokens_len - 1 # accepted +1 sampled from a target model # Process tokens and check for stop condition should_break = False for new_token in new_tokens: if new_token in stop_tokens: should_break = True break tokens.append(new_token) if should_break: break # Update positions for next iteration input_pos.add_(accepted_tokens_len) input_pos_maxp1 += accepted_tokens_len token = new_tokens[-1].unsqueeze(0) # Finalize generated sequence tokens = torch.stack(tokens) if include_prompt: tokens = torch.cat([prompt, tokens]) acceptance_rate = total_accepted / total_generated if total_generated > 0 else 0.0 return tokens, acceptance_rate def setup_model(config: Config, max_returned_tokens: int, fabric: L.Fabric) -> GPT: """Helper function to setup a model with common configuration.""" with fabric.init_module(empty_init=True): model = GPT(config) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() return fabric.setup_module(model) def load_model(checkpoint_dir: Path, fabric: L.Fabric) -> Tuple[Config, Path]: """Helper function to validate and load model configuration.""" check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = 
checkpoint_dir / "lit_model.pth" check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device) return config, checkpoint_path @torch.inference_mode() def main( draft_model_checkpoint_dir: Path, target_model_checkpoint_dir: Path, prompt: str = "What food do llamas eat?", *, sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, speculative_k: int = 3, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, precision: Optional[str] = None, compile: bool = False, ) -> None: """Default generation option. Generates text samples based on pre-trained models and a tokenizer. Args: draft_model: Smaller/faster model used for initial token predictions target_model: Larger/more accurate model used to verify draft predictions prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. speculative_k: Number of tokens to speculatively generate at each step top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. 
`top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. compile: Whether to compile the model. """ draft_model_checkpoint_dir = extend_checkpoint_dir(draft_model_checkpoint_dir) target_model_checkpoint_dir = extend_checkpoint_dir(target_model_checkpoint_dir) pprint(locals()) # Setup Fabric precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None and quantize.startswith("bnb."): if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." 
) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) # Load model configs and checkpoints draft_config, draft_checkpoint_path = load_model(draft_model_checkpoint_dir, fabric) target_config, target_checkpoint_path = load_model(target_model_checkpoint_dir, fabric) # Setup tokenizer and validate draft_tokenizer = Tokenizer(draft_model_checkpoint_dir) target_tokenizer = Tokenizer(target_model_checkpoint_dir) if draft_tokenizer.vocab_size != target_tokenizer.vocab_size: raise ValueError("Draft and target models have different vocab sizes.") tokenizer = target_tokenizer # Setup prompt prompt_style = ( load_prompt_style(target_model_checkpoint_dir) if has_prompt_style(target_model_checkpoint_dir) else PromptStyle.from_config(target_config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens # Initialize models fabric.print(f"Loading draft model {str(draft_checkpoint_path)!r} with {draft_config.__dict__}", file=sys.stderr) fabric.print(f"Loading target model {str(target_checkpoint_path)!r} with {target_config.__dict__}", file=sys.stderr) t0 = time.perf_counter() draft_model = setup_model(draft_config, max_returned_tokens, fabric) target_model = setup_model(target_config, max_returned_tokens, fabric) fabric.print(f"Time to instantiate models: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # Setup compilation if needed if compile: torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True global next_token next_token = torch.compile(next_token, mode="reduce-overhead") # Load model weights t0 = time.perf_counter() 
load_checkpoint(fabric, draft_model, draft_checkpoint_path) load_checkpoint(fabric, target_model, target_checkpoint_path) fabric.print(f"Time to load the models weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # Generate samples L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y, acceptance_rate = generate( draft_model, target_model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=([tokenizer.eos_id] if tokenizer.eos_id is not None else []), speculative_k=speculative_k, ) t = time.perf_counter() - t0 # Reset KV cache for model in (draft_model, target_model): for block in model.transformer.h: block.attn.kv_cache.reset_parameters() # Print results fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print(f"Acceptance rate: {acceptance_rate * 100:.2f}%") fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/generate/tp.py ================================================ """Tensor-parallel implementation adapted from https://github.com/pytorch-labs/gpt-fast/blob/14df27/tp.py""" import logging import sys import time import warnings from functools import partial from pathlib import Path from pprint import pprint from typing import Literal, Optional, Union import lightning as L import torch import torch._dynamo.config import torch._inductor.config from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only import litgpt.generate.base as generate_base from litgpt.config import Config from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE from 
litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_nvlink_connectivity, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, ) def tensor_parallel_linear(fabric: L.Fabric, linear: torch.nn.Linear, style: str) -> None: world_size = fabric.world_size dim, attr = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}[style] size = getattr(linear, attr) if size % world_size != 0: raise ValueError( f"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})" ) shard = torch.tensor_split(linear.weight, world_size, dim=dim)[fabric.global_rank] # overwrite `.data` instead of recreating the parameter for quantization (bitsandbytes) support. # the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses linear.weight.data = shard setattr(linear, attr, shard.size(dim)) if linear.bias is not None and dim == 0: shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank] linear.bias = torch.nn.Parameter(shard, requires_grad=linear.bias.requires_grad) def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None: if isinstance(mlp, LLaMAMLP): tensor_parallel_linear(fabric, mlp.fc_1, "colwise") tensor_parallel_linear(fabric, mlp.fc_2, "colwise") tensor_parallel_linear(fabric, mlp.proj, "rowwise") mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size)) elif isinstance(mlp, GptNeoxMLP): tensor_parallel_linear(fabric, mlp.fc, "colwise") tensor_parallel_linear(fabric, mlp.proj, "rowwise") mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size)) elif isinstance(mlp, LLaMAMoE): # we use expert slicing across ranks, alternatively, we could create a expert parallelism group # when the number of experts is a multiple of the world size for expert in mlp.experts: tensor_parallel_mlp(fabric, expert) else: raise NotImplementedError def 
tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: tensor_parallel_linear(fabric, attn.qkv, "colwise") tensor_parallel_linear(fabric, attn.proj, "rowwise") attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) def all_reduce_output(world_size: int, module: torch.nn.Module, ins, outs) -> torch.Tensor: from torch.distributed._functional_collectives import all_reduce return all_reduce(outs, "sum", list(range(world_size))) def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT: for block in model.transformer.h: tensor_parallel_mlp(fabric, block.mlp) tensor_parallel_attn(fabric, block.attn) # update the config values to the shard sizes # this is only relevant for `tensor_parallel_attn`, but it needs to run only once world_size = fabric.world_size attrs = ["n_head", "n_embd", "n_query_groups"] for attr in attrs: size = getattr(model.config, attr) if size % world_size != 0: raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})") setattr(model.config, attr, size // world_size) return model @torch.inference_mode() def main( checkpoint_dir: Path, prompt: str = "What food do llamas eat?", *, sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None, precision: Optional[str] = None, compile: bool = False, ) -> None: """Generation script that uses tensor parallelism to run across devices. Generates text samples based on a pre-trained model and tokenizer. Args: checkpoint_dir: The checkpoint directory to load. prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. 
top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. When specified, it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent to sampling the most probable token, while `top_p=1` samples from the whole distribution. It can be used in conjunction with `top_k` and `temperature` with the following order of application: 1. `top_k` sampling 2. `temperature` scaling 3. `top_p` sampling For more details, see https://arxiv.org/abs/1904.09751 or https://huyenchip.com/2024/01/16/sampling.html#top_p temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples. quantize: Whether to quantize the model and using which method: - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md precision: Indicates the Fabric precision setting to use. compile: Whether to compile the model. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) precision = precision or get_default_supported_precision(training=False) plugins = None if quantize is not None: if compile: raise NotImplementedError # untested if "mixed" in precision: raise ValueError("Quantization and mixed precision is not supported.") if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0: warnings.warn( "LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization." 
) dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision] bnb_logger = logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes") bnb_logger.setLevel(logging.DEBUG) bnb_logger.debug = rank_zero_only(bnb_logger.debug) plugins = BitsandbytesPrecision(quantize[4:], dtype) precision = None # set "ddp" as the strategy for the launching functionality, but there's no data-parallelism fabric = L.Fabric(devices="auto", strategy="ddp", precision=precision, plugins=plugins) if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1: check_nvlink_connectivity(fabric) fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") model_file = "lit_model.pth" checkpoint_path = checkpoint_dir / model_file tokenizer = Tokenizer(checkpoint_dir) prompt_style = ( load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config) ) prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt) encoded = tokenizer.encode(prompt, device=fabric.device) prompt_length = encoded.size(0) max_returned_tokens = prompt_length + max_new_tokens fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device # so that the CPU RAM doesn't OOM with larger models for rank in range(fabric.world_size): if fabric.global_rank == rank: t0 = time.perf_counter() state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") model.load_state_dict(state_dict, assign=True) print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # cannot use `.setup_module` because it will wrap with DDP model = fabric._precision.convert_module(model) t0 = time.perf_counter() model = tensor_parallel(fabric, model) print( f"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr, ) with fabric.init_tensor(): # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # the rope cache which is on meta device model.cos, model.sin = model.rope_cache() # enable the kv cache model.set_kv_cache(batch_size=1) model.eval() t0 = time.perf_counter() model = fabric.to_device(model) print(f"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) fabric.barrier() if compile: torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True generate_base.next_token = torch.compile(generate_base.next_token, mode="reduce-overhead") L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id ) t = time.perf_counter() - t0 for block in 
model.transformer.h: block.attn.kv_cache.reset_parameters() fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) ================================================ FILE: litgpt/lora.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. # Derived from https://github.com/microsoft/LoRA # ------------------------------------------------------------------------------------------ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ r""" Low Ranking Adaptation for LLMs scheme. ┌───────────────────┐ ┆ h ┆ └───────────────────┘ ▲ | + / \ ┌─────────────────┐ ╭───────────────╮ Matrix initialization: ┆ ┆ \ B / B = 0 ┆ pretrained ┆ \ r*d / A = N(0, sigma^2) ┆ weights ┆ ╰─────────╯ ┆ ┆ | r | r - rank ┆ W e R^(d*d) ┆ | ◀─────▶ | ┆ ┆ ╭─────────╮ └─────────────────┘ / A \ ▲ / d*r \ \ ╰───────────────╯ \ ▲ \ / \ / ┌───────────────────┐ ┆ x ┆ └───────────────────┘ With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d, we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen pretrained weights and thus fine-tune the model. 
The goal of this approach is to move weight updates into a separate matrix which is decomposed with
two matrices of a lower rank.
"""

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self

import litgpt
from litgpt.config import Config as BaseConfig
from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights


class LoRALayer(nn.Module):
    """Base class holding the LoRA hyperparameters (rank, alpha, dropout) and the merge state flag."""

    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
        """Store LoRA specific attributes in a class.

        Args:
            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2).
                Must be >= 0 (asserted); 0 disables the LoRA branch in subclasses.
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A).
                A value of 0.0 (or lower) makes `self.lora_dropout` an identity function.
        """
        super().__init__()
        assert r >= 0
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            # plain lambda (not a submodule), so no dropout module is registered when dropout is disabled
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False


class LoRALinear(LoRALayer):
    """A `torch.nn.Linear` wrapped with a low-rank update: y = linear(x) + dropout(x) @ A^T @ B^T * (alpha / r)."""

    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs: Any,
    ):
        """LoRA wrapper around linear class.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2).
                When 0, no `lora_A`/`lora_B` parameters are created.
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
            **kwargs: forwarded to `torch.nn.Linear` (e.g. `bias`)
        """
        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.empty((r, in_features)))
            self.lora_B = nn.Parameter(torch.empty((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the LoRA matrices only (pretrained `self.linear` weights are left untouched).

        `lora_A` gets Kaiming-uniform init and `lora_B` is zeroed, so the LoRA update starts at zero.
        Does nothing when the layer has no LoRA parameters (r == 0).
        """
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def get_lora_AB(self) -> torch.Tensor:
        """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.

        Computed as `(lora_B @ lora_A) * scaling`, i.e. shape (out_features, in_features).
        """
        return (self.lora_B @ self.lora_A) * self.scaling

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W).

        No-op when r == 0 or when already merged. Afterwards `self.merged` is True and `forward`
        skips the LoRA branch.
        """
        if self.r > 0 and not self.merged:
            pretrained_dtype = self.linear.weight.data.dtype
            lora_data = self.get_lora_AB()
            # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result
            # (a uint8 weight is treated as a bitsandbytes 4-bit `Params4bit` tensor)
            if pretrained_dtype == torch.uint8:
                import bitsandbytes as bnb

                weight = self.linear.weight
                # dequantize the pretrained weights
                weight_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).to(lora_data.dtype)
                # add pretrained and LoRA weights
                weight_data += lora_data
                # assign updated weights and quantize by moving to CUDA device
                self.linear.weight = bnb.nn.Params4bit(weight_data, requires_grad=False, **weight.__dict__)
                self.linear.weight.cuda(weight.device)
            else:
                # self.linear might be on CPU and lora_data on CUDA
                # the inplace add will preserve the dtype of linear.weight
                self.linear.weight.data += lora_data.to(device=self.linear.weight.data.device)
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;
        # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return pretrained + lora


class LoRAQKVLinear(LoRALinear):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        head_size: int,
        n_head: int,
        n_query_groups: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
        **kwargs: Any,
    ):
        """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
Args: in_features: number of input features of the pretrained weights out_features: number of output features of the pretrained weights head_size: size of a single attention head n_head: number of attention heads n_query_groups: number of query groups (see diagram in `litgpt/config.py`) r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) lora_alpha: alpha is needed for scaling updates as alpha/r "This scaling helps to reduce the need to retune hyperparameters when we vary r" https://arxiv.org/pdf/2106.09685.pdf (section 4.1) lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query` and `value` but keep `key` without weight updates we should pass `[True, False, True]` """ super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) self.linear = torch.nn.Linear(in_features, out_features, **kwargs) self.head_size = head_size self.n_head = n_head self.n_query_groups = n_query_groups if isinstance(enable_lora, bool): enable_lora = [enable_lora] * 3 assert len(enable_lora) == 3 self.enable_lora = enable_lora # Actual trainable parameters # To better understand initialization let's imagine that we have such parameters: # ⚬ in_features: 128 (embeddings_size) # ⚬ out_features: 384 (3 * embedding_size) # ⚬ r: 2 # ⚬ enable_lora: [True, False, True] if r > 0 and any(enable_lora): self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features))) # (4, 128) enable_q, enable_k, enable_v = enable_lora # qkv_shapes will be used to split a tensor with weights correctly qkv_shapes = ( # if `head_size` is explicitly 
specified in the config, `n_embd` (or `in_features`) # might not be equal to `head_size * n_head`, thus we use it directly here head_size * n_head * enable_q, head_size * n_query_groups * enable_k, head_size * n_query_groups * enable_v, ) self.qkv_shapes = [s for s in qkv_shapes if s] self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r)) # (256, 2)) # Notes about shapes above # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in # F.linear function weights are automatically transposed. In addition conv1d requires channels to # be before seq length # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is # 128*2; 2 tells to have two channels per group for group convolution # Scaling: # This balances the pretrained model`s knowledge and the new task-specific adaptation # https://lightning.ai/pages/community/tutorial/lora-llm/ # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can # tune these values to your needs. This value can be even slightly greater than 1.0! # https://github.com/cloneofsimo/lora self.scaling = self.lora_alpha / self.r self.reset_parameters() @property def lora_ind(self) -> torch.Tensor: """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used.""" # Indices are needed to properly pad weight updates with zeros. 
if not hasattr(self, "_lora_ind"): enable_q, enable_k, enable_v = self.enable_lora q_embd_size = self.head_size * self.n_head kv_embd_size = self.head_size * self.n_query_groups lora_ind = [] if enable_q: lora_ind.extend(range(0, q_embd_size)) if enable_k: lora_ind.extend(range(q_embd_size, q_embd_size + kv_embd_size)) if enable_v: lora_ind.extend(range(q_embd_size + kv_embd_size, self.linear.out_features)) self.register_buffer( "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False ) return self._lora_ind def zero_pad(self, x: torch.Tensor) -> torch.Tensor: """Properly pad the last dimension of weight updates with zeros. If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, then the weights update should be: [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], [....................................], [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] ↑ ↑ ↑ ________________________________________ | query | key | value | ---------------------------------------- Args: x: tensor with weights update that will be padded with zeros if necessary Returns: A tensor with weight updates and zeros for deselected q, k or v """ # we need to do zero padding only if LoRA is disabled for one of QKV matrices if all(self.enable_lora): return x # Let's image that: # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size) # ⚬ embeddings_size: 128 # ⚬ self.linear.out_features: 384 (3 * embeddings_size) # ⚬ enable_lora: [True, False, True] # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but # only for key updates (this is where self.lora_ind comes in handy) result = x.new_zeros(*x.shape[:-1], self.linear.out_features) # (64, 64, 384) if result.device.type == "mps": result[..., self.lora_ind] = x return result else: return 
result.index_copy_(dim=-1, index=self.lora_ind, source=x) # (64, 64, 384) def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. If the number of heads is equal to the number of query groups - grouped queries are disabled (see scheme in `litgpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the input and weight matrices will be split in equally sized parts and applied separately (like having multiple conv layers side by side). Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually, apply each part of the weight matrix to the corresponding input's part and concatenate the result. Args: input: input matrix of shape (B, C, T) weight: weight matrix of shape (C_output, rank, 1). "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class). 
Returns: A tensor with a shape (B, C_output, T) """ if self.n_head == self.n_query_groups: return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T) # Notation: # ⚬ N: number of enabled LoRA layers (self.enable_lora) # ⚬ C_output': embeddings size for each LoRA layer (not equal in size) # ⚬ r: rank of all LoRA layers (equal in size) input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1, # (B, C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" # Let's assume that: # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size) # ⚬ self.lora_A.data: (4, 128) # ⚬ self.lora_B.data: (256, 2) lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128) def merge(self) -> None: """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" if self.r > 0 and any(self.enable_lora) and not self.merged: super().merge() def forward(self, x: torch.Tensor) -> torch.Tensor: """Do the forward pass. If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. If not, then multiply pretrained weights with input, apply LoRA on input and do summation. 
Args: x: input tensor of shape (batch_size, context_length, embedding_size) Returns: Output tensor of shape (batch_size, context_length, 3 * embedding_size) """ # Let's assume that: # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size) # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size) # ⚬ self.lora_A.data: (4, 128) # ⚬ self.lora_B.data: (256, 2) # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass; # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights pretrained = self.linear(x) if self.r == 0 or not any(self.enable_lora) or self.merged: return pretrained after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4) # For F.conv1d: # ⚬ input: input tensor of shape (mini-batch, in_channels, iW) # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW) after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. Args: model: model with LoRA layers bias: ``"none"``: all bias weights will be frozen, ``"lora_only"``: only bias weight for LoRA layers will be unfrozen, ``"all"``: all bias weights will be unfrozen. 
Raises: NotImplementedError: if `bias` not in ["none", "lora_only", "all"] """ # freeze all layers except LoRA's for n, p in model.named_parameters(): if "lora_" not in n: p.requires_grad = False # depending on the `bias` value unfreeze bias weights if bias == "none": return if bias == "all": for n, p in model.named_parameters(): if "bias" in n: p.requires_grad = True elif bias == "lora_only": for m in model.modules(): if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: m.bias.requires_grad = True else: raise NotImplementedError def lora_filter(key: str, value: Any) -> bool: return "lora_" in key @dataclass class Config(BaseConfig): """ Args: lora_r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) lora_alpha: alpha is needed for scaling updates as alpha/r "This scaling helps to reduce the need to retune hyperparameters when we vary r" https://arxiv.org/pdf/2106.09685.pdf (section 4.1) lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) lora_*: whether to apply LoRA to the specified weights or not """ lora_r: int = 0 lora_alpha: int = 1 lora_dropout: float = 0.0 lora_query: bool = False lora_key: bool = False lora_value: bool = False lora_projection: bool = False lora_mlp: bool = False lora_head: bool = False @property def mlp_class(self) -> Type: return getattr(litgpt.lora, self.mlp_class_name) class GPT(BaseModel): # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. 
def __init__(self, config: Config) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config self.lm_head = create_lora_linear( config, config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias, use_r=config.lora_head, ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) self.mask_cache: Optional[torch.Tensor] = None self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) def _init_weights(self, module: nn.Module) -> None: """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" super()._init_weights(module) if isinstance(module, LoRALinear): module.reset_parameters() def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"} state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class Block(BaseBlock): def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) self.mlp = config.mlp_class(config) class CausalSelfAttention(BaseCausalSelfAttention): def __init__(self, config: Config, block_idx: int) -> None: super().__init__(config, block_idx) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = LoRAQKVLinear( in_features=config.n_embd, out_features=shape, r=config.lora_r, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout, 
enable_lora=(config.lora_query, config.lora_key, config.lora_value), bias=config.bias or config.attn_bias, # for MQA/GQA support head_size=config.head_size, n_head=config.n_head, n_query_groups=config.n_query_groups, ) # output projection self.proj = create_lora_linear( config, config.head_size * config.n_head, config.n_embd, use_r=config.lora_projection, ) def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" mapping = { "qkv.weight": "qkv.linear.weight", "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) for attr in ("weight", "bias"): legacy_key = f"{prefix}attn.linear.{attr}" current_key = f"{prefix}qkv.linear.{attr}" if legacy_key in state_dict: state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def create_lora_linear( config: Config, in_size: int, out_size: int, bias: Optional[Union[float, bool]] = None, use_r: Optional[bool] = None, ) -> LoRALinear: if bias is None: bias = config.bias if use_r is None: use_r = config.lora_mlp return LoRALinear( in_size, out_size, bias=bias, r=(config.lora_r if use_r else 0), lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout, ) class GptNeoxMLP(litgpt.model.GptNeoxMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) self.fc = create_lora_linear(config, config.n_embd, config.intermediate_size) self.proj = create_lora_linear(config, config.intermediate_size, config.n_embd) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { "fc.weight": "fc.linear.weight", "fc.bias": "fc.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": 
"proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class LLaMAMLP(litgpt.model.LLaMAMLP): def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None: nn.Module.__init__(self) self.intermediate_size = intermediate_size or config.intermediate_size self.fc_1 = create_lora_linear(config, config.n_embd, self.intermediate_size) self.fc_2 = create_lora_linear(config, config.n_embd, self.intermediate_size) self.proj = create_lora_linear(config, self.intermediate_size, config.n_embd) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { "fc_1.weight": "fc_1.linear.weight", "fc_1.bias": "fc_1.linear.bias", "fc_2.weight": "fc_2.linear.weight", "fc_2.bias": "fc_2.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) class LLaMAMoE(litgpt.model.LLaMAMoE): def __init__(self, config: Config) -> None: nn.Module.__init__(self) self.gate = create_lora_linear(config, config.n_embd, config.n_expert, bias=False) self.experts = nn.ModuleList( LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert) ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = {"gate.weight": "gate.linear.weight"} state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def merge_lora_weights(model: GPT) -> None:
    """Merge LoRA weights into the full-rank weights to speed up inference."""
    # calls `merge()` on every LoRALinear (including subclasses) found in the module tree
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.merge()

================================================
FILE: litgpt/model.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Full definition of a decoder-only transformer-based language model, all of it in this single file.

Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""

import math
from functools import partial
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import Self

from litgpt.config import Config
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble


class GPT(nn.Module):
    """Decoder-only transformer: token embedding (`wte`), a stack of `Block`s (`h`), a final norm (`ln_f`)
    and an output projection (`lm_head`). RoPE tables are kept as non-persistent `cos`/`sin` buffers."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.mask_cache: Optional[torch.Tensor] = None
        # goes through the property setter below, which also builds the RoPE cache
        self.max_seq_length = self.config.block_size

    @property
    def max_seq_length(self) -> int:
        """Maximum sequence length the RoPE cache is currently sized for."""
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory

        Raises:
            ValueError: if `value` exceeds `config.block_size`.
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}."
                " This is likely because the input text exceeds the supported context length of this model."
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        # override
        elif value != self.cos.size(0):
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
        # if the kv cache is expected
        if self.mask_cache is not None and self.mask_cache.shape[-1] < value:
            print(
                f"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!"
            )

    def reset_parameters(self) -> None:
        # Trigger resetting the rope-cache
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        # N(0, 0.02) for router, linear and embedding weights; linear biases are zeroed
        if isinstance(module, GroupedTopkRouter):
            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        input_pos_maxp1: Optional[int] = None,
        lm_head_chunk_size: int = 0,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        If `input_pos` is provided, the KV cache uses K and V vectors for
        positions smaller than entries in `input_pos`. For efficiency, pass
        `input_pos_maxp1` as `max(input_pos) + 1` if already available from
        your forward algorithm. This slices the KV cache buffers and speeds
        up multi-head attention.

        Without `input_pos_maxp1`, the computation uses the full KV cache
        (`max_seq_length`) with masking applied. Note that inferring
        `input_pos_maxp1` from `input_pos` causes graph breaks and prevents
        compilation.

        Args:
            idx: Token indices of input sequences, shape `(B, T)`, where `B`
                is batch size.
            input_pos: Optional. Positions of input tokens. The default is
                `arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index).
            input_pos_maxp1: Optional. See above.
            lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final
                `lm_head` computation is done in chunks of this size.

        Returns:
            Logit outputs, shape `(B, T, config.padded_vocab_size)`. If
            `lm_head_chunk_size > 0`, this is a list of chunks of shape
            `(B, lm_head_chunk_size, config.padded_vocab_size)`, the final
            entry can be shorter.

        Raises:
            ValueError: if `T` exceeds `max_seq_length`, or `input_pos` has a bad rank/length,
                or `input_pos_maxp1` exceeds `max_seq_length`.
            TypeError: if `input_pos` is given but `set_kv_cache()` was not called.
        """
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            if input_pos.dim() > 2:  # otherwise, things go wrong in `apply_rope`
                raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}")
            if input_pos.shape[-1] != T:
                raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same")
            # gather the RoPE rows for the requested positions
            cos = batched_index_select(self.cos, 0, input_pos)
            sin = batched_index_select(self.sin, 0, input_pos)
            if input_pos.dim() == 1:
                cos = cos.unsqueeze(0)
                sin = sin.unsqueeze(0)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = batched_index_select(self.mask_cache, 2, input_pos)
            if mask.dim() > 4:
                # the mask cache has a batch dim of 1 in addition to the one
                # we get if input_pos has a batch dimension
                mask = mask.view(*(mask.shape[0:1] + mask.shape[2:]))
            if input_pos_maxp1 is not None:
                # Shorten final dimension so it just covers all `input_pos` entries
                if input_pos_maxp1 > self.max_seq_length:
                    raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})")
                mask = mask[..., :input_pos_maxp1]
        else:
            # unsqueeze to have a batch dimension
            cos = self.cos[:T].unsqueeze(0)
            sin = self.sin[:T].unsqueeze(0)
            # `cos`, `sin` have shape (1, T, config.rope_n_elem)
            mask = None  # defaults to causal mask
            input_pos_maxp1 = None

        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
        if self.config.scale_embeddings:
            x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)

        for block_idx, block in enumerate(self.transformer.h):
            if self.config.rope_indices is not None:
                # per-block selection along the last axis of cos/sin (see `rope_cache_length` for the 3D layout)
                x = block(
                    x,
                    cos[..., self.config.rope_indices[block_idx]],
                    sin[..., self.config.rope_indices[block_idx]],
                    mask,
                    input_pos,
                    input_pos_maxp1,
                )
            else:
                x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
        x = self.transformer.ln_f(x)
        clamp_head = (
            partial(do_softcapping, thresh=self.config.final_logit_softcapping)
            if self.config.final_logit_softcapping is not None
            else nn.Identity()
        )
        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [clamp_head(self.lm_head(x_i)) for x_i in x.split(lm_head_chunk_size, dim=1)]
        else:
            return clamp_head(self.lm_head(x))  # (B, T, padded_vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the (cos, sin) RoPE tables for `self.max_seq_length`.

        The variant is chosen from `config.rope_adjustments`: Llama3-style (low/high freq factors),
        YaRN (beta_fast/beta_slow), linear (`factor` only) or standard RoPE (None / no known keys).

        Raises:
            ValueError: if Llama3 and YaRN parameters are mixed, or a variant's required parameters are missing.
        """
        if self.config.rope_adjustments is None:
            extra_config = None

        else:
            # Check for mutually exclusive parameter sets
            llama3_params = ["low_freq_factor", "high_freq_factor"]
            yarn_params = ["beta_fast", "beta_slow"]

            has_llama3 = any(param in self.config.rope_adjustments for param in llama3_params)
            has_yarn = any(param in self.config.rope_adjustments for param in yarn_params)

            if has_llama3 and has_yarn:
                raise ValueError(
                    "RoPE adjustments cannot contain both Llama3 parameters (low_freq_factor, high_freq_factor) "
                    "and YaRN parameters (beta_fast, beta_slow). These are mutually exclusive."
                )

            # Llama3-style RoPE
            if has_llama3:
                adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"]
                params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]
                if all(params_present):
                    extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
                else:
                    missing_params = [
                        param for param, present in zip(adjusted_params_required, params_present) if not present
                    ]
                    raise ValueError(
                        f"The following Llama3 RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
                        "All Llama3 parameters must be specified together."
                    )

            # YaRN-style RoPE
            elif has_yarn:
                # Required: factor, beta_fast, beta_slow, original_max_seq_len
                # Optional: mscale, mscale_all_dim
                yarn_required_params = ["factor", "beta_fast", "beta_slow", "original_max_seq_len"]
                params_present = [param in self.config.rope_adjustments for param in yarn_required_params]

                if not all(params_present):
                    missing_params = [
                        param for param, present in zip(yarn_required_params, params_present) if not present
                    ]
                    raise ValueError(
                        f"The following YaRN RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
                        "All YaRN required parameters must be specified together."
                    )

                extra_config = {name: self.config.rope_adjustments[name] for name in yarn_required_params}

                # Add optional YaRN parameters
                for param in ["mscale", "mscale_all_dim"]:
                    if param in self.config.rope_adjustments:
                        extra_config[param] = self.config.rope_adjustments[param]

            # Linear or standard RoPE
            elif "factor" in self.config.rope_adjustments:
                # linear RoPE
                adjusted_params_required = ["factor"]
                extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
            else:
                extra_config = None  # uses standard RoPE

        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
            extra_config=extra_config,
            rope_local_base_freq=self.config.rope_local_base_freq,
        )

    def rope_cache_length(self) -> int:
        """
        Extract the head dimension (n_elem) from RoPE cache regardless of shape.

        The RoPE cache can have different shapes depending on model configuration:
        - Standard RoPE: (seq_len, n_elem) - 2D tensor
        - Dual RoPE (local/global): (seq_len, n_elem, 2) - 3D tensor

        Returns:
            int: n_elem (head dimension for RoPE)
        """
        return self.cos.size(1)

    def set_kv_cache(
        self,
        batch_size: int,
        max_seq_length: Optional[int] = None,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Build a KV cache in every block's attention and (re)build the mask cache when its length changes.

        `rope_cache_length` defaults to `self.rope_cache_length()` and `max_seq_length` to `self.max_seq_length`.
        """
        if rope_cache_length is None:
            rope_cache_length = self.rope_cache_length()

        if max_seq_length is None:
            max_seq_length = self.max_seq_length

        # initialize the kv cache for all blocks
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size,
                max_seq_length,
                rope_cache_length,
                device,
                dtype,
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        """Drop the mask cache and every block's KV cache."""
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None


class Block(nn.Module):
    """One transformer layer: attention + MLP with optional pre/post norms and parallel or sequential residuals."""

    def __init__(
        self,
        config: Config,
        block_idx: int,
    ) -> None:
        super().__init__()
        if not config.parallel_residual and config.shared_attention_norm:
            raise NotImplementedError(
                "No checkpoint amongst the ones we support uses this configuration"
                " (non-parallel residual and shared attention norm)."
            )

        self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = (
            CausalSelfAttention(config, block_idx)
            if not config.latent_attention
            else MultiheadLatentAttention(config, block_idx)
        )
        self.post_attention_norm = (
            config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
        )
        # None when the norm_1 output is shared (only reachable with parallel residual, see check above)
        self.norm_2 = (
            nn.Identity()
            if not config.norm_2
            else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
        )
        self.mlp = config.mlp_class(config)
        # the first `first_k_dense_replace` blocks use a dense LLaMAMLP instead of `config.mlp_class`
        if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
            self.mlp = LLaMAMLP(config)
        self.post_mlp_norm = (
            config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
        )

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        input_pos_maxp1: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Non-parallel residual       Parallel residual
           ┌─ x                     ┌─ x ──────────────────┐             Note: if `shared_attention_norm` is True,
           │  ↓                     │  ↓                   ↓                   the output from `norm_1` is reused
           │  norm_1                │  norm_1  ───────►    norm_2
           │  ↓                     │  ↓                   ↓
           │  attn                  │  attn                MLP
           │  ↓                     │  ↓                   ↓
           |  post_attn_norm        |  post_attn_norm      post_mlp_norm
           |  ↓                     |  ↓                   ↓
        ┌─ └► +                     └► + ◄─────────────────┘
        |     ↓
        │     norm_2
        │     ↓
        │     MLP
        │     ↓
        |     post_mlp_norm
        |     ↓
        └───► +
        """

        x_normed = self.norm_1(x)
        attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1)
        attention_output = self.post_attention_norm(attention_output)

        if self.config.parallel_residual:
            # norm_2 must see the block input, so it runs before the residual add
            if not self.config.shared_attention_norm:
                x_normed = self.norm_2(x)
            x = attention_output + x
        else:
            x = attention_output + x
            x_normed = self.norm_2(x)

        return self.post_mlp_norm(self.mlp(x_normed)) + x


class CausalSelfAttention(nn.Module): def __init__(self, config: Config, block_idx: int) -> None: super().__init__() # key, query and value projections for all heads, but in a batch self.qkv = nn.Linear( config.n_embd, (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries bias=config.bias or config.attn_bias, ) # output projection self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = False if config.sliding_window_size is not None and config.sliding_window_indices is not None: self.apply_sliding_window_attention = config.sliding_window_indices[block_idx] if config.norm_qk: norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size norm_k_size = ( config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size ) self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps) self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps) else: self.norm_q = self.norm_k = None if config.rope_adjustments is not None: mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) scaling_factor = config.rope_adjustments.get("factor", None) if mscale_all_dim and scaling_factor: # YaRN self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) else: self.mscale = 1.0 else: self.mscale = 1.0 self.config = config self.block_idx = block_idx def forward( self, x:
torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, input_pos_maxp1: Optional[int] = None, ) -> torch.Tensor: # Notation: # - B | batch size # - T | time-step (sequence length) # - C | model's embeddings size (n_embd) # - C* | attentions's embeddings size # - hs | head size # - nh_(q,k,v) | number of heads for query, key and value # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads # alternative notation: num_kv_groups = n_query_groups # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ # │ │ │ │ │ │ │ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ # MHA GQA MQA # n_query_groups=4 n_query_groups=2 n_query_groups=1 # # credit https://arxiv.org/pdf/2305.13245.pdf head_size = self.config.head_size n_head = self.config.n_head n_query_groups = self.config.n_query_groups rope_n_elem = self.config.rope_n_elem B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` # instead of individually multiplying the input `x` with the respective weight matrices. qkv = self.qkv(x) # (B, T, 3xC*) # Define query, key and value sizes. # If grouped/multi query is enabled, these sizes are not equal (see the diagram above). query_size = n_head * head_size key_size = value_size = n_query_groups * head_size # Split qkv into query, key and value matrices. 
q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) if self.config.norm_qk and self.config.norm_qk_type == "olmo2": q = self.norm_q(q) k = self.norm_k(k) # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the # embedding size (C) into num_heads (nh) and head_size (hs). # The original GQA paper is followed here and the term query groups is used. # alternative notation: Query groups are also referred to as KV groups. q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs) k = k.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs) v = v.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs) # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector # of size `hs`. q = q.transpose(1, 2) # (B, nh_q, T, hs) k = k.transpose(1, 2) # (B, nh_k, T, hs) v = v.transpose(1, 2) # (B, nh_v, T, hs) if self.config.norm_qk and self.config.norm_qk_type == "default": q = self.norm_q(q) k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. if self.config.rope_interleave: q_roped = apply_rope_interleave(q[..., :rope_n_elem], cos, sin) k_roped = apply_rope_interleave(k[..., :rope_n_elem], cos, sin) else: q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) # Apply kv-cache during inference. 
if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) if self.apply_sliding_window_attention: actual_kv_len = k.size(2) if mask is not None and mask.size(-1) != actual_kv_len: mask = mask[..., :actual_kv_len] if input_pos_maxp1 is not None: # Subselect along sequence dimension k = k[..., :input_pos_maxp1, :] v = v[..., :input_pos_maxp1, :] # k, v: (B, nh_k, input_pos_maxp1, hs) # If input_pos_maxp1 is None -> max_seq_length # Grouped queries: balance the number of heads across all three matrices. # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. if n_query_groups != n_head and (input_pos is None or n_query_groups != 1): q_per_kv = n_head // n_query_groups k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) if self.apply_sliding_window_attention: """ Global Window Sliding window Sliding window attention mask + bias = attention mask ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ │ True False False False │ │ True True True True │ │ True False False False │ │ True True False False │ │ True True True True │ │ True True False False │ │ True True True False │ │ False True True True │ │ False True True False │ │ True True True True │ │ False False True True │ │ False False True True │ └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ """ if input_pos is None: if mask is None: mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), float("-inf")) mask = mask.view(1, 1, *mask.shape) sliding_window_mask = torch.full((T, T), float("-inf"), dtype=q.dtype, device=q.device) for i in range(T): window_start = max(0, i - self.config.sliding_window_size + 1) sliding_window_mask[i, window_start : 
i + 1] = 0.0 sliding_window_mask = sliding_window_mask.view(1, 1, T, T) mask = sliding_window_mask # Efficient attention using Flash Attention CUDA kernels. # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) # Re-assemble all head outputs side by side. y = y.reshape(B, T, head_size * n_head) # Output projection. return self.proj(y) # (B, T, C) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) scale = scale * self.mscale * self.mscale # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: scores = q @ k.mT * scale scores = do_softcapping(scores, self.config.attention_logit_softcapping) if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) scores = scores + mask scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v else: y = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None ) return y.transpose(1, 2) def build_kv_cache( self, batch_size: int, max_seq_length: int, rope_cache_length: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> "KVCache": if self.apply_sliding_window_attention and self.config.sliding_window_size is not None: effective_cache_size = min(max_seq_length, self.config.sliding_window_size) else: effective_cache_size = max_seq_length v_shape = (batch_size, self.config.n_query_groups, effective_cache_size, self.config.head_size) if rope_cache_length is None: if self.config.rotary_percentage != 1.0: raise 
TypeError( "Please pass the `rope_cache_length` parameter. " "Use `rope_cache_length=model.rope_cache_length()` to extract it automatically." ) k_shape = v_shape else: k_shape = ( batch_size, self.config.n_query_groups, effective_cache_size, rope_cache_length + self.config.head_size - self.config.rope_n_elem, ) return KVCache( k_shape, v_shape, device=device, dtype=dtype, is_sliding_window=self.apply_sliding_window_attention, sliding_window_size=self.config.sliding_window_size if self.apply_sliding_window_attention else None, ) def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with legacy checkpoints.""" for attr in ("weight", "bias"): legacy_key = f"{prefix}attn.{attr}" current_key = f"{prefix}qkv.{attr}" if legacy_key in state_dict: state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) class MultiheadLatentAttention(nn.Module): def __init__(self, config: Config, block_idx: int) -> None: super().__init__() self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias) self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps) self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias) self.kv_a_proj_with_mqa = nn.Linear( config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias ) self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps) self.kv_b_proj = nn.Linear( config.kv_lora_rank, config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim), bias=config.bias, ) # output projection self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None if config.rope_adjustments is not None: mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) scaling_factor = 
config.rope_adjustments.get("factor", None) if mscale_all_dim and scaling_factor: # YaRN self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) else: self.mscale = 1.0 else: self.mscale = 1.0 self.config = config self.block_idx = block_idx def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, input_pos_maxp1: Optional[int] = None, ) -> torch.Tensor: # Notation: # - B | batch size # - T | time-step (sequence length) # - C | model's embeddings size (n_embd) # - C* | attentions's embeddings size # - hs | head size # - nh_(q,k,v) | number of heads for query, key and value # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads # alternative notation: num_kv_groups = n_query_groups B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x))) # (B, T, n_head * qk_head_dim) q = q.view(B, T, -1, self.config.qk_head_dim) # (B, T, n_head, qk_head_dim) q = q.transpose(1, 2) # (B, n_head, T, qk_head_dim) q_pass, q_rot = torch.split(q, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(x) # (B, T, kv_lora_rank + qk_rope_head_dim) k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1) k_pass = self.kv_b_proj(self.kv_a_norm(k_pass)) k_pass = k_pass.view(B, T, self.config.n_query_groups, -1) k_pass = k_pass.transpose(1, 2) k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1) k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim) # (B, 1, T, qk_rope_head_dim) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. 
if self.config.rope_interleave: q_roped = apply_rope_interleave(q_rot, cos, sin) k_roped = apply_rope_interleave(k_rot, cos, sin) else: q_roped = apply_rope(q_rot, cos, sin) k_roped = apply_rope(k_rot, cos, sin) k_roped = k_roped.expand(*k_pass.shape[:-1], -1) # (B, n_head, T, qk_rope_head_dim) q = torch.cat((q_pass, q_roped), dim=-1) k = torch.cat((k_pass, k_roped), dim=-1) # Apply kv-cache during inference. if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) if input_pos_maxp1 is not None: # Subselect along sequence dimension k = k[..., :input_pos_maxp1, :] v = v[..., :input_pos_maxp1, :] # k, v: (B, nh_k, input_pos_maxp1, hs) # If input_pos_maxp1 is None -> max_seq_length # Grouped queries: balance the number of heads across all three matrices. # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): q_per_kv = self.config.n_head // self.config.n_query_groups k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) # Efficient attention using Flash Attention CUDA kernels. # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) # Re-assemble all head outputs side by side. y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim) # Output projection. 
return self.proj(y) # (B, T, C) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim) scale = scale * self.mscale * self.mscale # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: scores = q @ k.mT * scale scores = do_softcapping(scores, self.config.attention_logit_softcapping) if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) scores = scores + mask scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v else: y = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None ) return y.transpose(1, 2) def build_kv_cache( self, batch_size: int, max_seq_length: int, rope_cache_length: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> "KVCache": v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim) k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim) if rope_cache_length is not None: print("Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!") if self.config.rotary_percentage != 1.0: print("Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!") return KVCache(k_shape, v_shape, device=device, dtype=dtype) class GptNeoxMLP(nn.Module): def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None: super().__init__() self.intermediate_size = intermediate_size or config.intermediate_size self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias) self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias) self.config = config def forward(self, x: torch.Tensor) -> 
torch.Tensor: x = self.fc(x) x = F.gelu(x, approximate=self.config.gelu_approximate) return self.proj(x) class LLaMAMLP(nn.Module): def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None: super().__init__() self.intermediate_size = intermediate_size or config.intermediate_size self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias) self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias) self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias) self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = F.silu(x_fc_1) * x_fc_2 return self.proj(x) class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = F.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) class LLaMAMoE(nn.Module): def __init__(self, config: Config) -> None: super().__init__() self.gate = ( nn.Linear(config.n_embd, config.n_expert, bias=False) if not config.n_expert_groups else GroupedTopkRouter(config) ) self.experts = nn.ModuleList( LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert) ) if config.n_shared_expert: self.shared_experts = LLaMAMLP( config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert ) self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: """ Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 See also figure 1 in https://arxiv.org/abs/2211.15841 """ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) residual_x = x.clone() x = x.view(-1, C) # (B*T, C) if not self.config.n_expert_groups: router = self.gate(x) # (B*T, n_expert) probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token) probs = probs.softmax(dim=1, 
dtype=torch.float).to(dtype=x.dtype) else: probs, indices = self.gate(x) if self.config.routed_scaling_factor != 1.0: probs = probs * self.config.routed_scaling_factor masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device) masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token) y = torch.zeros_like(x) # (B*T, C) for mask, expert in zip(masks, self.experts): token_idx, expert_idx = torch.where(mask) y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx]) y = y.view(B, T, C) if self.config.n_shared_expert: y = y + self.shared_experts(residual_x) return y class GroupedTopkRouter(nn.Module): """ Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py. DeepseekV3TopkRouter class. """ def __init__(self, config: Config) -> None: super().__init__() self.config = config self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd)) self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert)) @torch.no_grad() def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor: scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups) .topk(self.config.n_topk_scores_per_group, dim=-1)[0] # Top k scores for each group .sum(dim=-1) ) group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) .expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups) .reshape(-1, self.config.n_expert) ) scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1] return 
topk_indices def forward(self, x: torch.Tensor) -> torch.Tensor: router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32)) scores = router_logits.sigmoid() topk_indices = self.get_topk_indices(scores) topk_weights = scores.gather(1, topk_indices) if self.config.norm_topk_prob: denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 topk_weights /= denominator return topk_weights, topk_indices # ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def build_rope_cache( seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1, extra_config: Optional[dict] = None, rope_local_base_freq: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Enhanced Transformer with Rotary Position Embedding. Args: seq_len (int): Sequence length. n_elem (int): Number of elements (head dimension). device (torch.device, optional): Device for tensor allocations. base (int, optional): Base for computing inverse frequencies. condense_ratio (int, optional): Ratio to condense the position indices. extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2) Returns: Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. Shapes are `(seq_len, n_elem)`. 
""" # Compute the inverse frequencies theta theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) # Initialize attention scaling factor (modified for YaRN) attention_scaling = 1.0 if extra_config is not None: factor = extra_config["factor"] # Check YaRN first (has beta_fast/beta_slow) if "beta_fast" in extra_config or "beta_slow" in extra_config: # YaRN-style RoPE scaling beta_fast = extra_config["beta_fast"] beta_slow = extra_config["beta_slow"] original_max_seq_len = extra_config["original_max_seq_len"] # Calculate attention scaling factor based on mscale and mscale_all_dim mscale = extra_config.get("mscale") mscale_all_dim = extra_config.get("mscale_all_dim") if mscale and mscale_all_dim: attention_scaling = yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim) elif mscale_all_dim: attention_scaling = yarn_get_mscale(factor, mscale_all_dim) elif mscale: attention_scaling = yarn_get_mscale(factor, mscale) # else: attention_scaling remains 1.0 # Create two frequency sets: extrapolation (unscaled) and interpolation (scaled) pos_freqs = base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem) theta_extrapolation = 1.0 / pos_freqs theta_interpolation = 1.0 / (factor * pos_freqs) # Find correction range based on rotation counts # Inverse dimension formula to find dimension based on number of rotations def find_correction_dim(num_rotations, dim, base_val, max_pos): return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base_val)) low_dim = find_correction_dim(beta_fast, n_elem, base, original_max_seq_len) high_dim = find_correction_dim(beta_slow, n_elem, base, original_max_seq_len) # Apply truncation if specified if extra_config.get("truncate", True): low_dim = math.floor(low_dim) high_dim = math.ceil(high_dim) low_dim = max(low_dim, 0) high_dim = min(high_dim, n_elem // 2 - 1) # Create linear ramp factor for blending dim_range = torch.arange(n_elem // 2, device=device, 
dtype=torch.float32) if low_dim == high_dim: high_dim += 0.001 # Prevent singularity linear_func = (dim_range - low_dim) / (high_dim - low_dim) ramp_func = torch.clamp(linear_func, 0.0, 1.0) # Blend extrapolation and interpolation frequencies # ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled) theta_extrapolation_factor = ramp_func theta = ( theta_interpolation * (1 - theta_extrapolation_factor) + theta_extrapolation * theta_extrapolation_factor ) elif "original_max_seq_len" in extra_config: # Llama3-style RoPE scaling orig_context_len = extra_config["original_max_seq_len"] low_freq_factor = extra_config["low_freq_factor"] high_freq_factor = extra_config["high_freq_factor"] wavelen = 2 * torch.pi / theta ratio = orig_context_len / wavelen smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) # Compute adjusted_theta without masked indexing adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta theta = adjusted_theta else: # Linear scaling fallback theta = theta / factor # Create position indices `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio # Calculate the product of position index and $\theta_i$ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) # If `n_elem` is odd, the final dimension of `idx_theta` has size # `n_elem + 1`, so need to cut something off. # Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave # `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to # broadcasting. If we shorten `idx_theta`, unit tests comparing to # Hugging Face fail. 
# https://github.com/huggingface/transformers/issues/35233 if idx_theta.shape[-1] > n_elem > 1: idx_theta = idx_theta[..., :n_elem] # if rope_local_base_freq is given, have a separate rope value for local embedding # For now, we use default RoPE for local embedding if rope_local_base_freq is not None: local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) local_idx_theta = torch.outer(seq_idx, local_theta) local_idx_theta = local_idx_theta.repeat(1, 2) if local_idx_theta.shape[-1] > n_elem > 1: local_idx_theta = local_idx_theta[..., :n_elem] idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1) cos = torch.cos(idx_theta) * attention_scaling sin = torch.sin(idx_theta) * attention_scaling return cos, sin def batched_index_select(t, dim, idx): """index_select for batched index and unbatched t""" if idx.dim() == 1: return torch.index_select(t, dim, idx) *batch_shape, idx_size = idx.shape res = torch.index_select(t, dim, idx.reshape(-1)) # flat index # split out single batch idx res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) if dim > 0: # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors dims = [dim] + list(range(res.dim())) del dims[dim + 1] res = res.permute(dims) # unflatten batch dims res = res.view(*batch_shape, *res.shape[1:]) return res def batched_index_copy_(t, dim, idx, val): """Index copy for batched t, idx, val""" if t.device.type == "mps": # Normalize negative dimensions if dim < 0: dim = t.dim() + dim if idx.dim() == 1: idx_shape = [1] * val.dim() idx_shape[dim] = -1 idx_expanded = idx.view(*idx_shape) idx_expanded = idx_expanded.expand_as(val) t.scatter_(dim, idx_expanded, val) return t elif idx.dim() == 2: assert dim != 0, "Cannot index the batch dimension" batch_size = idx.size(0) idx_size = idx.size(1) assert batch_size == t.size(0) == val.size(0) idx_shape = [batch_size] + [1] * (val.dim() - 1) idx_shape[dim] = idx_size idx_expanded = 
idx.view(*idx_shape)
            idx_expanded = idx_expanded.expand_as(val)

            t.scatter_(dim, idx_expanded, val)
            return t

        else:
            raise NotImplementedError(f"idx.dim() == {idx.dim()} not supported")

    else:
        if idx.dim() == 1:
            return t.index_copy_(dim, idx, val)

        assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}"
        assert dim != 0, f"cannot index batch dim {dim=}"
        batch_size, idx_size = idx.shape
        assert batch_size == t.size(0)
        assert batch_size == val.size(0)

        # if we can view the batch and indexed dimensions together, we could
        # do index trickery. This is, sadly, not the case for kvcache so we
        # fall back to for loop
        for i in range(batch_size):
            # `t[i]` drops the batch dim, so a non-negative `dim` shifts down by one.
            unbatched_dim = dim if dim < 0 else dim - 1
            t[i].index_copy_(unbatched_dim, idx[i], val[i])
        return t


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """
    Applies RoPE transform to `x`. Note that `cos`, `sin` need to have a batch
    dimension.

    The rotation pairs element `i` with element `i + head_size // 2` ("half-split" layout).

    Args:
        x: Input tensor, `(B, ..., T, head_size)`
        cos: Cached cosines, `(B, T, head_size)` or `(1, T, head_size)`
        sin: Cached sines, `(B, T, head_size)` or `(1, T, head_size)`

    Returns:
        Encoded tensor, `(B, ..., T, head_size)`, in the dtype of `x`.

    Raises:
        ValueError: If `cos` is not 3-dimensional or `cos` and `sin` differ in shape.
    """
    if cos.dim() != 3:
        raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}")
    if cos.shape != sin.shape:
        raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}")
    head_size_half = x.size(-1) // 2
    x1 = x[..., :head_size_half]  # (B, ..., T, head_size/2)
    x2 = x[..., head_size_half:]  # (B, ..., T, head_size/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, ..., T, head_size)
    dims_diff = x.dim() - cos.dim()
    if dims_diff > 0:
        # Ensure that shapes of `x`, `cos`, `sin` align
        # by inserting singleton dims after the batch dim, e.g. (B, T, hs) -> (B, 1, T, hs).
        new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]
        cos = cos.view(*new_shape)
        sin = sin.view(*new_shape)

    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)


def apply_rope_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings with interleaved tensor layout.

    This version rearranges the input tensor to group even/odd indices separately
    before applying the standard RoPE rotation, matching HuggingFace's
    apply_rotary_pos_emb_interleave behavior.

    Note that the output stays in the regrouped (even indices first, then odd) layout;
    it is not permuted back.

    Args:
        x: Input tensor of shape (..., seq_len, head_dim)
        cos: Cosine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)
        sin: Sine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)

    Returns:
        Tensor with RoPE applied, same shape and dtype as input

    Raises:
        ValueError: If `cos` is not 3-dimensional or `cos` and `sin` differ in shape.
    """
    if cos.dim() != 3:
        raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}")
    if cos.shape != sin.shape:
        raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}")

    # Rearrange tensor to group even/odd indices: [x0,x1,x2,x3,...] -> [x0,x2,x4,...,x1,x3,x5,...]
    *batch_dims, d = x.shape
    x = x.view(*batch_dims, d // 2, 2).transpose(-1, -2).reshape(*batch_dims, d)

    # Standard rotation logic (same as apply_rope)
    head_size_half = x.size(-1) // 2
    x1 = x[..., :head_size_half]
    x2 = x[..., head_size_half:]
    rotated = torch.cat((-x2, x1), dim=-1)

    # Auto-detect dimension mismatch and reshape cos/sin
    dims_diff = x.dim() - cos.dim()
    if dims_diff > 0:
        new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]
        cos = cos.view(*new_shape)
        sin = sin.view(*new_shape)

    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)


def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor:
    """Soft-caps `x` as `thresh * tanh(x / thresh)` (bounded by `±thresh`, assuming `thresh > 0`)."""
    return torch.tanh(x / thresh) * thresh


class KVCache(nn.Module):
    """
    Preallocated key/value buffers that are written at explicit sequence positions.

    Buffers `k`, `v` are allocated with shapes `k_shape`, `v_shape`. The cache length is
    `k_shape[2]`, and positions are written along dim -2. The attention layers in this module
    build them as `(batch_size, heads, cache_len, dim)`, e.g.
    `(batch_size, n_query_groups, max_seq_length, head_size)`.
    With `is_sliding_window=True`, positions wrap around modulo the cache length (circular buffer).
    """

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        is_sliding_window: bool = False,
        sliding_window_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Non-persistent buffers: they follow `.to()`/device moves but are excluded from `state_dict`.
        self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
        self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
        self.is_sliding_window = is_sliding_window
        # Stored for reference; `forward` uses `max_cache_len` for wrap-around.
        self.sliding_window_size = sliding_window_size
        self.max_cache_len = k_shape[2]

    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Writes new values `k` and `v` into the cache at the positions specified
        by `input_pos` along the sequence dimension (`max_seq_length`). The batch
        size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch
        size. Returns the full buffers, adjusted to the batch size `bs`.

        In sliding-window mode, positions are taken modulo the cache length. While the largest
        position is still below the cache length, only the first `max(input_pos) + 1` slots are
        returned.

        Args:
            input_pos: Position index, `(bs, T)` or `(T,)`
            k: New values, `(bs, n_query_groups, T, head_size)`
            v: New values, `(bs, n_query_groups, T, head_size)`

        Returns:
            k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)`
        """
        # move the buffer to the activation dtype for when AMP is used
        if self.k.dtype != k.dtype:
            self.k = self.k.to(k.dtype)
        if self.v.dtype != v.dtype:
            self.v = self.v.to(v.dtype)
        # update the cache
        # `self.k[:bs, ...]` is a view, so the in-place index copy also updates the buffer.
        bs = k.size(0)
        if self.is_sliding_window:
            # Circular buffer for sliding window
            cache_positions = input_pos % self.max_cache_len
            k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)
            v = batched_index_copy_(self.v[:bs, ...], -2, cache_positions, v)
            max_pos = input_pos.max().item()
            if max_pos < self.max_cache_len:
                k = k[:, :, : max_pos + 1, :]
                v = v[:, :, : max_pos + 1, :]
        else:
            # Standard KV cache (global attention)
            k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
            v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)
        return k, v

    def
reset_parameters(self) -> None: torch.nn.init.zeros_(self.k) torch.nn.init.zeros_(self.v) def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) return torch.tril(ones).unsqueeze(0).unsqueeze(0) class RMSNorm(torch.nn.Module): """Root Mean Square Layer Normalization. Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. """ def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.ones(size)) self.eps = eps self.dim = dim self.add_unit_offset = add_unit_offset def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype x = x.float() # NOTE: the original RMSNorm paper implementation is not equivalent norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) weight = (1 + self.weight) if self.add_unit_offset else self.weight return (x_normed * weight.float()).to(dtype=dtype) def reset_parameters(self) -> None: torch.nn.init.ones_(self.weight) ================================================ FILE: litgpt/parser_config.py ================================================ import sys from pathlib import Path from typing import List, Optional from litgpt.utils import CLI def parser_commands() -> List[str]: return [ "download", "chat", "finetune", "finetune_lora", "finetune_full", "finetune_adapter", "finetune_adapter_v2", "pretrain", "generate", "generate_full", "generate_adapter", "generate_adapter_v2", "generate_sequentially", "generate_speculatively", "generate_tp", "convert_to_litgpt", "convert_from_litgpt", "convert_pretrained_checkpoint", "merge_lora", "evaluate", "serve", ] def save_hyperparameters( function: callable, checkpoint_dir: Path, known_commands: Optional[List[str]] = None, ) 
-> None: """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint.""" from jsonargparse import capture_parser # TODO: Make this more robust # This hack strips away the subcommands from the top-level CLI # to parse the file as if it was called as a script if known_commands is None: known_commands = parser_commands() known_commands = [(c,) for c in known_commands] for known_command in known_commands: unwanted = slice(1, 1 + len(known_command)) if tuple(sys.argv[unwanted]) == known_command: sys.argv[unwanted] = [] parser = capture_parser(lambda: CLI(function)) config = parser.parse_args() parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True) ================================================ FILE: litgpt/pretrain.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import math import pprint import time import warnings from dataclasses import asdict from datetime import timedelta from functools import partial from pathlib import Path from typing import Dict, Optional, Tuple, Union import lightning as L import torch import torch.nn as nn from lightning.fabric.strategies import FSDPStrategy from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops from torch.utils.data import DataLoader from torchmetrics.aggregation import RunningMean from typing_extensions import Literal from litgpt import Tokenizer from litgpt.args import EvalArgs, LogArgs, TrainArgs from litgpt.config import name_to_config from litgpt.constants import _TORCH_EQUAL_2_7, _TORCH_EQUAL_2_8 from litgpt.data import DataModule, TinyLlama from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP from litgpt.parser_config import save_hyperparameters from litgpt.types import LoggerChoice from litgpt.utils import ( CycleIterator, capture_hparams, check_nvlink_connectivity, choose_logger, chunked_cross_entropy, 
copy_config_files, extend_checkpoint_dir, find_resume_path, get_default_supported_precision, init_out_dir, instantiate_torch_optimizer, num_parameters, parse_devices, reset_parameters, save_config, ) def setup( model_name: str, model_config: Optional[Config] = None, out_dir: Path = Path("out/pretrain"), precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None, initial_checkpoint_dir: Optional[Path] = None, resume: Union[bool, Literal["auto"], Path] = False, data: Optional[DataModule] = None, train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, global_batch_size=512, micro_batch_size=4, max_tokens=int(3e12), # 3 trillion max_norm=1.0, min_lr=4e-5, lr_warmup_steps=2000, tie_embeddings=False, ), eval: EvalArgs = EvalArgs(interval=1000, max_iters=100), log: LogArgs = LogArgs(), optimizer: Union[str, Dict] = "AdamW", devices: Union[int, str] = "auto", num_nodes: int = 1, tokenizer_dir: Optional[Path] = None, logger_name: LoggerChoice = "tensorboard", seed: int = 42, ): """Pretrain a model. Arguments: model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Use "list" to list the supported models. model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with ``model_config``. Overrides the `model_name` if specified. out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in /teamspace/jobs//share. precision: The precision to use for finetuning. Determines a compatible precision setting by default. initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from. Useful for continued pretraining. Mutually exclusive with ``resume``. resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. 
Passing ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. devices: How many devices/GPUs to use. Uses all GPUs by default. num_nodes: How many nodes the code is being run on. tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data module require this. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. """ if model_name == "list": available_models = "\n".join(sorted(name_to_config)) print(f"Available values:\n{available_models}") quit() if initial_checkpoint_dir is not None: initial_checkpoint_dir = extend_checkpoint_dir(initial_checkpoint_dir) if tokenizer_dir is not None: tokenizer_dir = extend_checkpoint_dir(tokenizer_dir) if model_config is None: # Support both model_name options: meta-llama/Meta-Llama-3-8B & Meta-Llama-3-8B try: model_config = Config.from_name(model_name) except ValueError: print(f"Model name {model_name} is not supported.\n") available_models = "\n".join(sorted(name_to_config)) print(f"Available values:\n{available_models}") quit() hparams = capture_hparams() data = TinyLlama() if data is None else data config = Config.from_name(model_name) if model_config is None else model_config precision = precision or get_default_supported_precision(training=True) devices = parse_devices(devices) out_dir = init_out_dir(out_dir) # in case the dataset requires the Tokenizer tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None logger = choose_logger( logger_name, out_dir, name=f"pretrain-{config.name}", resume=bool(resume), log_interval=train.log_interval, log_args=asdict(log), ) if devices * 
num_nodes > 1: strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD") else: strategy = "auto" fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=[logger]) if torch.cuda.is_available() and devices > 1: check_nvlink_connectivity(fabric) fabric.launch() fabric.print(pprint.pformat(hparams)) if logger_name in ("tensorboard", "wandb", "mlflow"): fabric.logger.log_hyperparams(hparams) main( fabric=fabric, devices=devices, num_nodes=num_nodes, seed=seed, initial_checkpoint_dir=initial_checkpoint_dir, resume=resume, config=config, data=data, out_dir=out_dir, tokenizer_dir=tokenizer_dir, tokenizer=tokenizer, train=train, eval=eval, optimizer=optimizer, ) def main( fabric: L.Fabric, devices: int, seed: int, initial_checkpoint_dir: Optional[Path], resume: Union[bool, Literal["auto"], Path], config: Config, data: DataModule, out_dir: Path, tokenizer_dir: Optional[Path], tokenizer: Optional[Tokenizer], train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], num_nodes: int = 1, ) -> None: validate_args(train, eval, initial_checkpoint_dir, resume) if fabric.global_rank == 0: out_dir.mkdir(parents=True, exist_ok=True) fabric.seed_everything(seed) # same seed for every process to init model (FSDP) t0 = time.perf_counter() with fabric.init_module(empty_init=True): model = GPT(config) initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd) if train.tie_embeddings: model.transformer.wte.weight = model.lm_head.weight if train.max_seq_length: model.max_seq_length = train.max_seq_length fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") fabric.print(f"Total parameters: {num_parameters(model):,}") model = torch.compile(model) model = fabric.setup(model) extra_kwargs = {"fused": fabric.device.type == "cuda"} optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), **extra_kwargs) optimizer = 
fabric.setup_optimizers(optimizer) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length) train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) if initial_checkpoint_dir: fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model) state = { "model": model, "optimizer": optimizer, "train_dataloader": train_dataloader, "iter_num": 0, "step_count": 0, } resume = find_resume_path(resume, out_dir) if resume: fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) train_time = time.perf_counter() # work around PyTorch issue https://github.com/pytorch/pytorch/issues/152162 # which does not like the lazy initialization to be called in dynamo. # TODO: Happens with PyTorch 2.7+ if ( (_TORCH_EQUAL_2_7 or _TORCH_EQUAL_2_8) and (model._forward_module.__class__.__name__ == "OptimizedModule") and (model._forward_module._orig_mod.__class__.__name__ == "FullyShardedDataParallel") ): from torch.distributed.fsdp._runtime_utils import _root_pre_forward _root_pre_forward(model._forward_module._orig_mod, model._forward_module._orig_mod, [], {}) fit( fabric=fabric, devices=devices, num_nodes=num_nodes, state=state, train_dataloader=train_dataloader, val_dataloader=val_dataloader, out_dir=out_dir, tokenizer_dir=tokenizer_dir, train=train, eval=eval, ) # Save final checkpoint save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth") total_tokens = state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size # Print formatted output separator = "-" * 40 fabric.print(separator) fabric.print("| Performance") fabric.print(f"| - Total tokens : {total_tokens:,}") fabric.print(f"| - Training Time : {(time.perf_counter() - train_time):.2f} s") fabric.print(f"| - Tok/sec : {total_tokens / train_time:.2f} tok/s") fabric.print("| " + "-" * 40) if fabric.device.type == "cuda": memory_used = torch.cuda.max_memory_allocated() / 1e9 
fabric.print("| Memory Usage") fabric.print(f"| - Memory Used : {memory_used:.2f} GB") fabric.print(separator) def fit( fabric: L.Fabric, devices: int, state: dict, train_dataloader: DataLoader, val_dataloader: DataLoader, out_dir: Path, tokenizer_dir: Optional[Path], train: TrainArgs, eval: EvalArgs, num_nodes: int = 1, ) -> None: model = state["model"] optimizer = state["optimizer"] if eval.initial_validation: val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters) val_loss = f"{val_loss:.3f}" else: fabric.print("Verifying settings ...") validate(fabric, model, val_dataloader, max_iters=2, verbose=False) # sanity check val_loss = "n/a" throughput = ThroughputMonitor(fabric, window_size=5) with torch.device("meta"): meta_model = GPT(model.config) x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length)) model_fwd = lambda: meta_model(x) # noqa: F821 model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0) # noqa: F821 measured_flops = measure_flops(meta_model, model_fwd, model_loss) fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") del meta_model, x max_tokens_per_device = train.max_tokens // fabric.world_size tokens_per_iter = train.micro_batch_size * model.max_seq_length max_iters = max_tokens_per_device // tokens_per_iter log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes) initial_iter = state["iter_num"] train_iterator = CycleIterator(train_dataloader) running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to( fabric.device ) fabric.barrier() total_t0 = time.perf_counter() warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader) for train_data in train_iterator: if state["iter_num"] >= max_iters: break # determine and set the learning rate for this iteration lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr) for 
param_group in optimizer.param_groups: param_group["lr"] = lr state["iter_num"] += 1 iter_t0 = time.perf_counter() input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long() targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long() is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0 with fabric.no_backward_sync(model, enabled=is_accumulating): logits = model(input_ids) loss = chunked_cross_entropy(logits, targets) fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes)) running_loss.update(loss.detach()) if not is_accumulating: fabric.clip_gradients(model, optimizer, max_norm=train.max_norm) optimizer.step() optimizer.zero_grad() state["step_count"] += 1 if state["iter_num"] % log_iter_interval == 0: loss = running_loss.compute().item() # expensive device-to-host synchronization t1 = time.perf_counter() throughput.update( time=(t1 - total_t0), flops=(measured_flops * log_iter_interval), batches=state["iter_num"], samples=(state["iter_num"] * train.micro_batch_size), lengths=(state["iter_num"] * train.micro_batch_size * model.max_seq_length), ) metrics = { "loss": loss, "iter": state["iter_num"], "step": state["step_count"], "epoch": train_iterator.epoch, "iter_time": t1 - iter_t0, "remaining_time": ( (t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"]) ), "tokens": state["iter_num"] * train.micro_batch_size * model.max_seq_length, "total_tokens": (state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size), "learning_rate": lr, } if isinstance(val_loss, float): val_loss = f"{val_loss:.3f}" fabric.print( f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |" f" loss train: {metrics['loss']:.3f}," f" val: {val_loss} |" f" iter time: {metrics['iter_time'] * 1000:.2f} ms" f"{' (step)' if not is_accumulating else ''}" f" remaining time: 
{timedelta(seconds=int(metrics['remaining_time']))!s}" ) throughput_metrics = throughput.compute() metrics.update(throughput_metrics) fabric.log_dict(metrics, step=state["iter_num"] - 1) if val_dataloader is not None and not is_accumulating and state["step_count"] % eval.interval == 0: t0 = time.perf_counter() val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters) val_loss = val_loss.item() td = time.perf_counter() - t0 fabric.print(f"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms") metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics, step=state["iter_num"] - 1) fabric.barrier() if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0: save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth") # Final validation if eval.final_validation: val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters) metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} fabric.log_dict(metrics, step=state["iter_num"]) fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}") @torch.no_grad() def validate( fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int, verbose: bool = True ) -> torch.Tensor: fabric.barrier() if verbose: fabric.print("Validating ...") model.eval() losses = [] for k, batch in enumerate(val_dataloader): if k >= max_iters: break input_ids = batch[:, 0 : model.max_seq_length].contiguous().long() targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long() logits = model(input_ids) loss = chunked_cross_entropy(logits, targets) losses.append(loss) val_loss = torch.stack(losses).mean() model.train() fabric.barrier() return val_loss def get_dataloaders( fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int ) -> Tuple[DataLoader, 
DataLoader]: data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size) with fabric.rank_zero_first(): data.prepare_data() data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() return train_dataloader, val_dataloader # learning rate decay scheduler (cosine with linear warmup) def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float: # 1) linear warmup for warmup_iters steps if it < warmup_iters: return learning_rate * it / warmup_iters # 2) if it > max_iters, return min learning rate if it > max_iters: return min_lr # 3) in between, use cosine decay down to min learning rate decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 return min_lr + coeff * (learning_rate - min_lr) def initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None: """GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745).""" # Adapted from https://github.com/jzhang38/TinyLlama def init_weights(module, std): nn.init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: nn.init.zeros_(module.bias) for mod in model.modules(): if isinstance(mod, (nn.Embedding, nn.Linear)): mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd)) # need a separate loop because `mod.proj` below is a `nn.Linear` too for mod in model.modules(): if isinstance(mod, (LLaMAMLP, CausalSelfAttention)): mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer)) if not isinstance(fabric.strategy, FSDPStrategy): reset_parameters(model) def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file): model = state["model"] checkpoint_file.parent.mkdir(parents=True, exist_ok=True) fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}") 
fabric.save(checkpoint_file, state) if fabric.global_rank == 0: save_hyperparameters(setup, checkpoint_file.parent) if tokenizer_dir is not None: copy_config_files(tokenizer_dir, checkpoint_file.parent) save_config(model.config, checkpoint_file.parent) def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None: issues = [] unsupported = [(train, ["epochs"]), (eval, ["max_new_tokens"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}") if train.max_steps is not None: warnings.warn( "`train.max_steps` is intended for profiling or debug runs only. " "For full pretraining runs, prefer `train.max_tokens` or `train.max_time`.", UserWarning, ) required = [(train, ["max_tokens", "max_norm"])] for args, names in required: for name in names: if getattr(args, name) is None: issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}") if initial_checkpoint_dir and resume: issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.") if issues: raise ValueError("\n".join(issues)) ================================================ FILE: litgpt/prompts.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import importlib import re from abc import abstractmethod from json import dumps from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import yaml from litgpt.config import Config if TYPE_CHECKING: from litgpt import Tokenizer class PromptStyle: """Base interface for prompt styles.""" @abstractmethod def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return prompt def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ([tokenizer.eos_id],) @classmethod def from_name(cls, name: str) -> "PromptStyle": return prompt_styles[name]() @classmethod def from_config(cls, config: Config) -> "PromptStyle": return model_name_to_prompt_style(config.name) class Default(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return prompt def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ([tokenizer.eos_id],) class Alpaca(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: if kwargs.get("input"): sys_prompt = sys_prompt or ( "Below is an instruction that describes a task, paired with an input that provides further context. " "Write a response that appropriately completes the request.\n\n" ) return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Input:\n{kwargs['input']}\n\n### Response:\n" sys_prompt = sys_prompt or ( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" ) return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n" class FLAN(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or ( "Below is an instruction that describes a task. 
" "Write a response that appropriately completes the request.\n\n" ) return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n" class Longform(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or ( "Below is an instruction that describes a task, paired with an input that provides further context. " "Write a response that appropriately completes the request.\n\n" ) return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n" class StableLMAlpha(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or ( "# StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language" " model developed by StabilityAI.\n- StableLM is excited to be able to help the user, but will refuse to do" " anything that could be considered harmful to the user.\n- StableLM is more than just an information" " source, StableLM is also able to write poetry, short stories, and make jokes.\n- StableLM will refuse to" " participate in anything that could harm a human." 
) return f"<|SYSTEM|>{sys_prompt}<|USER|>{prompt}<|ASSISTANT|>" def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("<|SYSTEM|>")], [tokenizer.token_to_id("<|ASSISTANT|>")], [tokenizer.token_to_id("<|USER|>")], ) class StableLMZephyr(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n" class Falcon(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"{prompt}\nAnswer:" def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], # the model rarely emits the eos token and instead outputs newlines, but we cannot use them # to stop or else things like code generation wouldn't work [tokenizer.token_to_id("User"), tokenizer.token_to_id(":")], [193, tokenizer.token_to_id("User")], # 193: '\n' ) class Falcon3(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n" def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("<|endoftext|>")], ) class Llama2FunctionCalling(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: # Has to be before the llama config b_func, e_func = "", "\n\n" b_inst, e_inst = "[INST]", "[/INST]" b_sys, e_sys = "<>\n", "\n<>\n\n" # This is an example for how to format functions for the model function_metadata = { "function": "search_bing", "description": ( "Search the web for content on Bing. This allows users to search online/the internet/the web for" " content." ), "arguments": [{"name": "query", "type": "string", "description": "The search query string"}], } system_prompt = sys_prompt or ( "You are a helpful, respectful and honest assistant. 
Always answer as helpfully as" "possible. Your only response should be JSON formatted functions" ) # replace the curly braces with double curly braces to escape them function_list = dumps(function_metadata).replace("{", "{{").replace("}", "}}") return ( f"{b_func}{function_list.strip()}{e_func}{b_inst}{b_sys}{system_prompt.strip()}{e_sys}{prompt}{e_inst}\n\n" ) class Llama2(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: b_inst, e_inst = "[INST]", "[/INST]" b_sys, e_sys = "<>\n", "\n<>\n\n" sys_prompt = sys_prompt or ( "You are a helpful, respectful and honest assistant. Always answer as helpfully as" " possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist," " toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and" " positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why" " instead of answering something not correct. If you don't know the answer to a question, please don't" " share false information." ) return f"{b_inst} {b_sys}{sys_prompt}{e_sys} {prompt} {e_inst} " class Llama3(PromptStyle): def apply( self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str ) -> str: default_system_prompt = sys_prompt or "You are a helpful assistant." 
# https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229 if isinstance(prompt, str): return ( "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" f"{default_system_prompt}<|eot_id|>" # No newline "<|start_header_id|>user<|end_header_id|>\n\n" f"{prompt}<|eot_id|>" # No newline "<|start_header_id|>assistant<|end_header_id|>\n\n" ) elif isinstance(prompt, list): def encode_header(role: str) -> List[str]: return [f"<|start_header_id|>{role}<|end_header_id|>\n\n"] def encode_message(message: Dict[str, str]) -> List[str]: tokens = encode_header(message["role"]) # NOTE: Meta stripped this. I'm not sure I agree, but who am I to argue? tokens.append(message["content"].strip()) tokens.append("<|eot_id|>") return tokens def has_system_prompt(messages: List[Dict[str, str]]) -> bool: return messages[0].get("role", "") == "system" if len(messages) else False tokens = ["<|begin_of_text|>"] if not has_system_prompt(prompt): tokens.extend(encode_message({"role": "system", "content": default_system_prompt})) for i, message in enumerate(prompt): if i != 0 and message["role"] == "system": raise ValueError("'system' role is only allowed at the beginning of the conversation list.") if message["role"] not in ["assistant", "user", "system"]: raise ValueError( f"Unknown role: '{message['role']}'. Supported roles are 'assistant', 'user', and 'system'." 
) tokens.extend(encode_message(message)) tokens.extend(encode_header("assistant")) return "".join(tokens) else: raise ValueError(f"Unsupported prompt type: {type(prompt)}") def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("<|eot_id|>")], ) class R1Base(PromptStyle): def apply( self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str ) -> str: default_system_prompt = sys_prompt or "" bos_token = "<|begin▁of▁sentence|>" eos_token = "" if isinstance(prompt, str): return f"{default_system_prompt}<|User|>{prompt}<|Assistant|>" # Prepares for assistant response elif isinstance(prompt, list): def encode_message(message: Dict[str, str]) -> str: role = message["role"] content = message["content"].strip() if role == "system": return content # System prompt is prepended at the start elif role == "user": return f"<|User|>{content}" elif role == "assistant": return f"<|Assistant|>{content}{eos_token}" else: raise ValueError(f"Unknown role: '{role}'. Supported roles are 'assistant', 'user', and 'system'.") # Extract system prompt (if any) system_prompt = "" if prompt[0].get("role") == "system": system_prompt = prompt[0]["content"] prompt = prompt[1:] # Remove system message from the list # Construct the formatted prompt formatted_prompt = system_prompt for message in prompt: formatted_prompt += encode_message(message) formatted_prompt += "<|Assistant|>" # Prepares for assistant response return formatted_prompt else: raise ValueError(f"Unsupported prompt type: {type(prompt)}") def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("<|end▁of▁sentence|>")], ) class FreeWilly2(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or "This is a system prompt, please behave and help the user." 
return f"### System:\n{sys_prompt}\n\n### User:\n{prompt}\n\n### Assistant:\n" class Platypus(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"### Instruction:\n\n{prompt}\n\n### Response:\n" class StableCode(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"###Instruction\n{prompt}###Response\n" class CodeLlama(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: # for CodeLLama, we don't set a default system prompt, but it is supported: # https://huggingface.co/blog/codellama#conversational-instructions # Mistral does not: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format b_inst, e_inst = "[INST]", "[/INST]" if sys_prompt: b_sys, e_sys = "<>\n", "\n<>\n\n" return f"{b_inst} {b_sys}{sys_prompt}{e_sys}{prompt} {e_inst}" return f"{b_inst} {prompt} {e_inst}" class Phi1(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"{prompt}\n\nAnswer:" def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")], [198, tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")], # the model rarely emits the eos token and instead outputs newlines, but we cannot use them # to stop or else things like code generation wouldn't work # [198, 198], # '\n', '\n' ) class Phi2(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"Instruct: {prompt}\nOutput:" class Phi3(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or "You are a helpful assistant." 
return f"<|system|>\n{sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n" class Phi4(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: res = "" if sys_prompt: res += f"<|im_start|>system<|im_sep|>{sys_prompt}<|im_end|>" res += f"<|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>" return res class Phi4Reasoning(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = ( sys_prompt or "You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: {Thought section} {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. 
Now, try to solve the following question through the above guidelines:" ) return f"<|im_start>system<|im_sep|>{sys_prompt}<|im_end|><|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>" class Phi4Mini(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: res = "" if sys_prompt: res += f"<|system|>{sys_prompt}<|end|>" res += f"<|user|>{prompt}<|end|><|assistant|>" return res class Phi4MiniReasoning(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or "Your name is Phi, an AI math expert developed by Microsoft." return f"<|system|>{sys_prompt}<|end|><|user|>{prompt}<|end|><|assistant|>" class TinyLlama(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or "You are a friendly chatbot who always gives helpful, detailed, and polite answers." return f"<|system|>\n{sys_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n" class Gemma(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"user\n{prompt}\nmodel\n" class OLMo(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n" class ChatML(PromptStyle): def __init__(self, system_message: Optional[str] = None): self.system_message = system_message def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: sys_prompt = sys_prompt or self.system_message return ( f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" ) class Qwen2_5(ChatML): def __init__(self): super().__init__("You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.") class Qwen2_5_Math(ChatML): def __init__(self): super().__init__("Please reason step by step, and put your final answer within \\boxed{}.") class QwQ(ChatML): def __init__(self): super().__init__( "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." ) class Qwen3(ChatML): def __init__(self): super().__init__() class SmolLM2(ChatML): def __init__(self): super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face") class Salamandra(ChatML): def __init__(self): super().__init__( "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit." 
) # Maps prompt style names to PromptStyle classes prompt_styles: Dict[str, Type[PromptStyle]] = { # Dataset-specific prompt styles "alpaca": Alpaca, "flan": FLAN, "longform": Longform, # Model-specific prompt styles "stablelm-alpha": StableLMAlpha, "stablelm-zephyr": StableLMZephyr, "falcon": Falcon, "llama2-function-calling": Llama2FunctionCalling, "llama2": Llama2, "freewilly2": FreeWilly2, "platypus": Platypus, "stablecode": StableCode, "codellama": CodeLlama, "phi-1": Phi1, "phi-2": Phi2, "phi-3": Phi3, "phi-4": Phi4, "phi-4-reasoning": Phi4Reasoning, "phi-4-mini": Phi4Mini, "phi-4-mini-reasoning": Phi4MiniReasoning, "tinyllama": TinyLlama, "gemma": Gemma, "llama3": Llama3, "olmo": OLMo, "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, "qwen3": Qwen3, "smollm2": SmolLM2, "salamandra": Salamandra, } def model_name_to_prompt_style(model_name: str) -> PromptStyle: if re.search(r"stablelm-tuned-alpha", model_name): return StableLMAlpha() if re.search(r"stablelm-zephyr-3b", model_name): return StableLMZephyr() if re.search("stablecode-instruct", model_name): return StableCode() if re.search(r"Falcon3.*-Instruct", model_name): return Falcon3() if re.search(r"falcon.*-instruct", model_name): return Falcon() if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name): return Llama2FunctionCalling() if re.search("Llama-2.*-chat", model_name): return Llama2() if re.search("Llama-3.*-Instruct", model_name): return Llama3() if re.search("Llama-3.*-Instruct-*", model_name): return Llama3() if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name): return Llama3() if re.search("R1", model_name): return R1Base() if re.search("FreeWilly2", model_name): return FreeWilly2() if re.search("Platypus", model_name): return Platypus() if re.search("CodeLlama|Mi[sx]tral.*Instruct", model_name): return CodeLlama() if re.search("phi-1", model_name): return Phi1() if re.search("phi-2", model_name): return Phi2() if re.search("Phi-3", model_name): return Phi3() if 
re.search("Phi-4-reasoning", model_name): return Phi4Reasoning() if re.search("Phi-4-mini-reasoning", model_name): return Phi4MiniReasoning() if re.search("Phi-4-mini", model_name): return Phi4Mini() if re.search("phi-4", model_name): return Phi4() if re.search(r"tiny-llama.*chat", model_name): return TinyLlama() if re.search(r"(Code)?Gemma.*-it", model_name): return Gemma() if re.search(r"OLMo.*-hf", model_name): return OLMo() if re.search(r"Qwen2\.5-Math-.*", model_name): return Qwen2_5_Math() if re.search(r"Qwen2\.5-.*", model_name): return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() if re.search(r"Qwen3-.*", model_name): return Qwen3() if re.search(r"SmolLM2.*-Instruct", model_name): return SmolLM2() if re.search(r"salamandra-.*-instruct", model_name): return Salamandra() return Default() def save_prompt_style(style: Union[str, PromptStyle], checkpoint_dir: Path) -> None: style = PromptStyle.from_name(style) if isinstance(style, str) else style cls = type(style) # Allow saving the full module path for user-defined prompt classes config = {"class_path": f"{cls.__module__}.{cls.__name__}"} with open(checkpoint_dir / "prompt_style.yaml", "w", encoding="utf-8") as file: yaml.dump(config, file) def load_prompt_style(checkpoint_dir: Path) -> PromptStyle: with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file: config = yaml.safe_load(file) # Support loading the full module path for user-defined prompt classes full_module_path, cls_name = config["class_path"].rsplit(".", 1) module = importlib.import_module(full_module_path) cls = getattr(module, cls_name) return cls() def has_prompt_style(checkpoint_dir: Path) -> bool: return (checkpoint_dir / "prompt_style.yaml").is_file() ================================================ FILE: litgpt/scripts/__init__.py ================================================ ================================================ FILE: litgpt/scripts/convert_hf_checkpoint.py 
================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import gc import json import os import re import warnings from collections import defaultdict from functools import partial from pathlib import Path from pprint import pprint from typing import Dict, List, Optional, Tuple, Union import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor from safetensors.torch import load_file as load_safetensors from tqdm import tqdm from litgpt.config import Config from litgpt.utils import ( extend_checkpoint_dir, incremental_save, lazy_load, save_config, ) def copy_weights_gpt_neox( config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "gpt_neox.embed_in.weight": "transformer.wte.weight", "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias", "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, "gpt_neox.layers.{}.attention.bias": None, "gpt_neox.layers.{}.attention.masked_bias": None, "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", 
"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", "embed_out.weight": "lm_head.weight", } if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) for from_name, param in hf_weights.items(): name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template] if to_name is None: continue to_name = to_name.format(layer_idx) param = load_param(param, from_name, dtype, verbose=debug_mode) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_falcon( config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "transformer.word_embeddings.weight": "transformer.wte.weight", "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", "transformer.ln_f.bias": "transformer.ln_f.bias", "transformer.ln_f.weight": "transformer.ln_f.weight", "lm_head.weight": 
"lm_head.weight", } # the original model definition is different for each size if "7b" in config.name: weight_map.update( { "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", } ) elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", } ) else: raise NotImplementedError if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) for from_name, param in hf_weights.items(): name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template].format(layer_idx) param = load_param(param, from_name, dtype, verbose=debug_mode) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] 
param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_hf_llama( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name == "LLaMAMoE": weight_map.update( { "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{}.mlp.gate.weight", "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", 
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: raise NotImplementedError if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map[name_template] param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_gemma_2( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, 
debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight", "model.layers.{}.pre_feedforward_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight", "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map[name_template] param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split 
across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_gemma_3( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, config: Optional[Config] = None, ) -> None: GEMMA3_LANGUAGE_MODEL_PREFIX = ( "model.language_model" if any(k.startswith("model.language_model") for k in hf_weights) else "language_model.model" ) GEMMA3_VISION_MODEL_PREFIX = ( "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower" ) GEMMA3_MM_PROJECTOR_PREFIX = ( "model.multi_modal_projector" if any(k.startswith("model.multi_modal_projector") for k in hf_weights) else "multi_modal_projector" ) weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight", "model.layers.{}.pre_feedforward_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight", "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight", "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight", } if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) # gemma3 4b+ are multimodel models, but we are only loading the text weights is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights) if is_multimodal: warnings.warn("For Gemma3 models only the text component is supported.") new_weight_map = dict() prefix = "model" for k, v in weight_map.items(): if k.startswith(prefix): k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :] new_weight_map[k] = v weight_map = new_weight_map for from_name, param in hf_weights.items(): if from_name.startswith(GEMMA3_VISION_MODEL_PREFIX) or from_name.startswith(GEMMA3_MM_PROJECTOR_PREFIX): continue name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map.get(name_template) param = load_param(param, from_name, dtype, verbose=debug_mode) # in multimodal models, the text weights are the first part of the weights if is_multimodal and to_name == "transformer.wte.weight" and config is not None: param = param[: config.vocab_size] if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not 
None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_phi( config: Config, qkv_weights: dict, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: if any(layer_name.startswith(("layers.", "transformer.")) for layer_name in hf_weights): raise ValueError( "You are using an outdated Phi checkpoint. 
Please reload it as described in 'tutorials/download_phi.md'" ) weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.q_proj.bias": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.k_proj.bias": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.v_proj.bias": None, "model.layers.{}.self_attn.dense.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.dense.bias": "transformer.h.{}.attn.proj.bias", "model.layers.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight", "model.layers.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias", "model.layers.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight", "model.layers.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias", "model.final_layernorm.weight": "transformer.ln_f.weight", "model.final_layernorm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", "lm_head.bias": "lm_head.bias", } if config.name.startswith(("Phi-3", "phi-4", "Phi-4")): weight_map.update( { "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight", "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", "model.norm.weight": "transformer.ln_f.weight", } ) if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, layer_idx = layer_template(from_name) param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = 
qkv_weights.setdefault(layer_idx, defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param elif from_name.endswith("gate_up_proj.weight"): weight = load_param(param, f"layer {layer_idx} gate_up_proj", dtype, verbose=debug_mode) fc_1, fc_2 = weight.chunk(2, dim=0) state_dict[f"transformer.h.{layer_idx}.mlp.fc_1.weight"] = fc_1 state_dict[f"transformer.h.{layer_idx}.mlp.fc_2.weight"] = fc_2 continue to_name = weight_map[name_template] if to_name is None: continue to_name = to_name.format(layer_idx) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict and config.name.startswith("Phi-4"): state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_qwen_2_5( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.input_layernorm.weight": 
"transformer.h.{}.norm_1.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.q_proj.bias": None, "model.layers.{}.self_attn.k_proj.bias": None, "model.layers.{}.self_attn.v_proj.bias": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map[name_template] param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = 
load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_olmo2( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight", "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight", "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias", "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: raise NotImplementedError if progress_per_file is not None: progress_per_file = 
progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map[name_template] param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def copy_weights_qwen_3( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 
"model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight", "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight", "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name == "LLaMAMoE": weight_map.update( { "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", "model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", "model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight", } ) elif config.mlp_class_name == "LLaMAMLP": weight_map.update( { "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: raise NotImplementedError if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) for from_name, param in hf_weights.items(): name_template, *ids = layer_template(from_name, num_matches=2) to_name = weight_map[name_template] param = load_param(param, from_name, dtype, verbose=debug_mode) if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) weight_name, weight_type = from_name.split(".")[-2:] qkv[weight_type][weight_name] = param if to_name is None: continue to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if 
progress_per_file is not None: pbar.update(progress_per_file) if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: # qkv is split across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...] """ q_per_kv = config.n_head // config.n_query_groups qs = [] ks = [] vs = [] for chunk in torch.chunk(param, config.n_query_groups): split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) qs.append(split[0]) ks.append(split[1]) vs.append(split[2]) q = torch.cat(qs) k = torch.cat(ks) v = torch.cat(vs) return torch.cat((q, k, v)) def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: pattern = r"\.(\d+)\." 
if not (search_res := re.findall(pattern, layer_name)): return layer_name, -1 layer_name_template = re.sub(pattern, ".{}.", layer_name, count=num_matches) return layer_name_template, *(int(x) for x in search_res[:num_matches]) def load_param( param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool = False ) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` if verbose: print(f"Loading {name!r} into RAM") param = param._load_tensor() if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: if verbose: print(f"Converting {name!r} from {param.dtype} to {dtype}") param = param.to(dtype) return param @torch.inference_mode() def convert_hf_checkpoint( checkpoint_dir: Path, *, model_name: Optional[str] = None, dtype: Optional[str] = None, debug_mode: Optional[bool] = False, ) -> None: """ Convert a Hugging Face Transformers checkpoint into a LitGPT compatible checkpoint. Arguments: checkpoint_dir: Where to save the downloaded files. model_name: The existing config name to load. This is useful to download alternative weights of existing architectures. dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the dtype they are downloaded in. debug_mode: Prints the individual layers being loaded instead of a progress bar, which can be useful when developing and adding new models to LitGPT. 
""" checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) if model_name is None: model_name = checkpoint_dir.name if dtype is not None: dtype = getattr(torch, dtype) config = Config.from_name(model_name) save_config(config, checkpoint_dir) if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, config) elif model_name.lower().startswith("gemma-2"): qkv_weights = {} copy_fn = partial(copy_weights_gemma_2, qkv_weights) elif model_name.lower().startswith("gemma-3"): qkv_weights = {} copy_fn = partial(copy_weights_gemma_3, qkv_weights, config=config) elif model_name.lower().startswith("phi"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_phi, config, qkv_weights) elif model_name.lower().startswith(("qwen2.5", "qwq")): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights) elif model_name.lower().startswith("olmo-2-"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_olmo2, config, qkv_weights) elif model_name.lower().startswith("qwen3"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_qwen_3, config, qkv_weights) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) else: copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} # Load the json file containing weight mapping pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" model_safetensor_map_json_path = checkpoint_dir / "model.safetensors.index.json" if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file with open(pytorch_bin_map_json_path, encoding="utf-8") as json_map: bin_index = json.load(json_map) bin_files = {checkpoint_dir / bin for bin in 
bin_index["weight_map"].values()} elif model_safetensor_map_json_path.is_file(): with open(model_safetensor_map_json_path, encoding="utf-8") as json_map: bin_index = json.load(json_map) bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} else: bin_files = set(checkpoint_dir.glob("*.bin")) | set(checkpoint_dir.glob("*.safetensors")) # some checkpoints serialize the training arguments bin_files = {f for f in bin_files if f.name != "training_args.bin"} if not bin_files: raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files") with incremental_save(checkpoint_dir / "lit_model.pth") as saver: # for checkpoints that split the QKV across several files, we need to keep all the bin files # open, so we use `ExitStack` to close them all together at the end if not debug_mode: # Using tqdm progress bar when not in debug mode total_size = max(1, sum(os.path.getsize(bin_file) for bin_file in bin_files)) total_progress = 100 with tqdm( total=total_progress, desc="Initializing", bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}", ) as pbar: for bin_file in sorted(bin_files): pbar.set_description(f"Loading weights: {bin_file.name}") current_file_size = os.path.getsize(bin_file) progress_per_file = (current_file_size / total_size) * total_progress hf_weights = ( load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file) ) copy_fn( sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode, ) gc.collect() if pbar.n < total_progress: pbar.update(total_progress - pbar.n) pbar.close() else: # Handling files without progress bar in debug mode for bin_file in sorted(bin_files): hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file) copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode) print(f"Saving converted checkpoint to {checkpoint_dir}") 
saver.save(sd) ================================================ FILE: litgpt/scripts/convert_lit_checkpoint.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import gc from collections import defaultdict from functools import partial from pathlib import Path from pprint import pprint from typing import Dict, Optional, Union import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor from litgpt import Config from litgpt.scripts.convert_hf_checkpoint import layer_template, load_param from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load def copy_weights_falcon( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", "transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "transformer.ln_f.bias", "transformer.ln_f.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size if "7b" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", 
"transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", } ) else: raise NotImplementedError for from_name, param in lit_weights.items(): name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template].format(layer_idx) param = load_param(param, from_name, None) if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", "transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", "transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", "lm_head.weight": "embed_out.weight", } for from_name, param in 
lit_weights.items(): name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template].format(layer_idx) param = load_param(param, from_name, None) if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_llama( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], untie_weights: bool = False, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name == "LLaMAMoE": weight_map.update( { "transformer.h.{}.mlp.gate.weight": "model.layers.{}.block_sparse_moe.gate.weight", "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.block_sparse_moe.experts.{}.w1.weight", "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.block_sparse_moe.experts.{}.w3.weight", "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.block_sparse_moe.experts.{}.w2.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": 
"model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gemma_2( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], untie_weights: bool = True, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.pre_feedforward_layernorm.weight", "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", } for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and 
untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gemma_3( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], untie_weights: bool = True, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.pre_feedforward_layernorm.weight", "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight", "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight", } for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and 
untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_phi( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.dense.weight", "transformer.h.{}.attn.proj.bias": "model.layers.{}.self_attn.dense.bias", "transformer.h.{}.mlp.fc.weight": "model.layers.{}.mlp.fc1.weight", "transformer.h.{}.mlp.fc.bias": "model.layers.{}.mlp.fc1.bias", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.fc2.weight", "transformer.h.{}.mlp.proj.bias": "model.layers.{}.mlp.fc2.bias", "transformer.ln_f.weight": "model.final_layernorm.weight", "transformer.ln_f.bias": "model.final_layernorm.bias", "lm_head.weight": "lm_head.weight", "lm_head.bias": "lm_head.bias", } if config.name.lower().startswith(("phi-3", "phi-4")): weight_map.update( { "transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": 
"model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.ln_f.weight": "model.norm.weight", } ) gate_up_proj_weights = defaultdict(dict) for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and config.name.startswith("Phi-4"): continue name_template, layer_idx = layer_template(from_name) param = load_param(param, from_name, None) if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): if config.name.lower().startswith(("phi-3", "phi-4")): to_names = (weight_map[name_template].format(layer_idx),) params = (param,) else: weight_type = from_name.split(".")[-1] # weight or bias to_names = ( f"model.layers.{{}}.self_attn.q_proj.{weight_type}".format(layer_idx), f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) elif from_name.endswith((".fc_1.weight", ".fc_2.weight")): weight = load_param(param, from_name, None) weight_name = from_name.split(".")[-2] gate_up_proj_weights[layer_idx][weight_name] = weight else: to_names = (weight_map[name_template].format(layer_idx),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if config.name.lower().startswith(("phi-3", "phi-4")): for layer_idx in list(gate_up_proj_weights): fc_1_weight = gate_up_proj_weights[layer_idx]["fc_1"] fc_2_weight = gate_up_proj_weights[layer_idx]["fc_2"] weight = torch.concat([fc_1_weight, fc_2_weight], dim=0) layer_name = f"model.layers.{layer_idx}.mlp.gate_up_proj.weight" state_dict[layer_name] = weight del gate_up_proj_weights[layer_idx] def copy_weights_qwen_2_5( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, 
NotYetLoadedTensor]], untie_weights: bool = False, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", } for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): weight_type = from_name.split(".")[-1] # weight or bias to_names = ( "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_olmo2( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], untie_weights: bool = False, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", 
"transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight", "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_qwen_3( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], untie_weights: bool = False, saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", 
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight", "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", } if config.mlp_class_name == "LLaMAMoE": weight_map.update( { "transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight", "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight", "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight", "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.mlp.experts.{}.down_proj.weight", } ) elif config.mlp_class_name == "LLaMAMLP": weight_map.update( { "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError for from_name, param in lit_weights.items(): if from_name == "lm_head.weight" and untie_weights: continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) if from_name.endswith(".attn.qkv.weight"): weight_type = from_name.split(".")[-1] # weight or bias to_names = ( "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), ) params = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) for 
to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] """ q, k, v = param.split( ( config.n_head * config.head_size, config.n_query_groups * config.head_size, config.n_query_groups * config.head_size, ) ) qs = q.split(config.n_head // config.n_query_groups * config.head_size) ks = k.split(config.head_size) vs = v.split(config.head_size) interleaved = [t for group in zip(qs, ks, vs) for t in group] return torch.cat(interleaved) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: if any("lora" in wn for wn in lit_weights): raise ValueError("Checkpoints with LoRA weights cannot be converted. Call `scripts/merge_lora.py` first.") if any("adapter" in wn or "gating_factor" in wn for wn in lit_weights): raise NotImplementedError("Converting adapter models is not supported.") @torch.inference_mode() def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: """Convert a LitGPT trained checkpoint into a Hugging Face Transformers checkpoint.""" checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) config = Config.from_file(checkpoint_dir / "model_config.yaml") output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / "model.pth" if "falcon" in config.name: copy_fn = partial(copy_weights_falcon, config) elif config.name.startswith("Gemma-2"): copy_fn = partial(copy_weights_gemma_2, config) elif config.name.startswith("Gemma-3"): copy_fn = partial(copy_weights_gemma_3, config) elif config.name.lower().startswith("phi"): copy_fn = partial(copy_weights_phi, config) elif config.name.lower().startswith(("qwen2.5", "qwq")): copy_fn = partial(copy_weights_qwen_2_5, config) elif 
config.name.lower().startswith("olmo-2-"): copy_fn = partial(copy_weights_olmo2, config) elif config.name.lower().startswith("qwen3"): copy_fn = partial(copy_weights_qwen_3, config) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) else: copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} with incremental_save(output_path) as saver: lit_weights = lazy_load(checkpoint_dir / "lit_model.pth") lit_weights = lit_weights.get("model", lit_weights) check_conversion_supported(lit_weights) copy_fn(sd, lit_weights, saver=saver) gc.collect() saver.save(sd) ================================================ FILE: litgpt/scripts/convert_pretrained_checkpoint.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from pathlib import Path from pprint import pprint import torch from litgpt.utils import copy_config_files, extend_checkpoint_dir, incremental_save @torch.inference_mode() def convert_pretrained_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: """Convert a checkpoint after pretraining. The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training is finished. This script will export the state-dict of the model and place it in the chosen output folder, which then can be loaded by other scripts for inference, evaluation, etc. Args: checkpoint_dir: Path to a checkpoint directory produced by ``litgpt.pretrain``. output_dir: The output folder where the converted state-dict file and config files will be saved to. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) pprint(locals()) if output_dir.is_dir() and output_dir.glob("*"): raise FileExistsError( f"The output folder exists and is not empty: {str(output_dir)}." 
" Please delete it first or choose a different name." ) output_dir.mkdir(parents=True) checkpoint_file = checkpoint_dir / "lit_model.pth" output_checkpoint_file = output_dir / "lit_model.pth" # TODO: Consolidate sharded checkpoint if applicable # Extract the model state dict and save to output folder with incremental_save(output_checkpoint_file) as saver: print("Processing", checkpoint_file) full_checkpoint = torch.load(str(checkpoint_file), mmap=True) loaded_state_dict = full_checkpoint["model"] converted_state_dict = {} for param_name, param in loaded_state_dict.items(): saver.store_early(param) # remove prefix for compiled model (if any) param_name = param_name.replace("_orig_mod.", "") converted_state_dict[param_name] = param print(f"Saving converted checkpoint to {str(output_checkpoint_file)}.") saver.save(converted_state_dict) copy_config_files(checkpoint_dir, output_dir) ================================================ FILE: litgpt/scripts/download.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import importlib.util import os from contextlib import contextmanager from pathlib import Path from typing import List, Optional, Tuple from litgpt.config import configs from litgpt.constants import _HF_TRANSFER_AVAILABLE, _SAFETENSORS_AVAILABLE from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint def download_from_hub( repo_id: str, access_token: Optional[str] = os.getenv("HF_TOKEN"), tokenizer_only: bool = False, convert_checkpoint: bool = True, dtype: Optional[str] = None, checkpoint_dir: Path = Path("checkpoints"), model_name: Optional[str] = None, ) -> None: """Download weights or tokenizer data from the Hugging Face Hub. Arguments: repo_id: The repository ID in the format ``org/name`` or ``user/name`` as shown in Hugging Face. If "list" is provided as input, a list of the currently supported models in LitGPT and quits. 
access_token: Optional API token to access models with restrictions. tokenizer_only: Whether to download only the tokenizer files. convert_checkpoint: Whether to convert the checkpoint files to the LitGPT format after downloading. dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the dtype they are downloaded in. checkpoint_dir: Where to save the downloaded files. model_name: The existing config name to use for this repo_id. This is useful to download alternative weights of existing architectures. """ options = [f"{config['hf_config']['org']}/{config['hf_config']['name']}" for config in configs] if repo_id == "list": print("Please specify --repo_id . Available values:") print("\n".join(sorted(options, key=lambda x: x.lower()))) return if model_name is None and repo_id not in options: print( f"Unsupported `repo_id`: {repo_id}." "\nIf you are trying to download alternative " "weights for a supported model, please specify the corresponding model via the `--model_name` option, " "for example, `litgpt download NousResearch/Hermes-2-Pro-Llama-3-8B --model_name Llama-3-8B`." "\nAlternatively, please choose a valid `repo_id` from the list of supported models, which can be obtained via " "`litgpt download list`." 
) return from huggingface_hub import snapshot_download if importlib.util.find_spec("hf_transfer") is None: print( "It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`" ) download_files = ["tokenizer*", "generation_config.json", "config.json"] if not tokenizer_only: bins, safetensors = find_weight_files(repo_id, access_token) if bins: # covers `.bin` files and `.bin.index.json` download_files.append("*.bin*") elif safetensors: if not _SAFETENSORS_AVAILABLE: raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE)) download_files.append("*.safetensors*") else: raise ValueError(f"Couldn't find weight files for {repo_id}") import huggingface_hub._snapshot_download as download import huggingface_hub.constants as constants previous = constants.HF_HUB_ENABLE_HF_TRANSFER if _HF_TRANSFER_AVAILABLE and not previous: print("Setting HF_HUB_ENABLE_HF_TRANSFER=1") constants.HF_HUB_ENABLE_HF_TRANSFER = True download.HF_HUB_ENABLE_HF_TRANSFER = True directory = checkpoint_dir / repo_id with gated_repo_catcher(repo_id, access_token): snapshot_download( repo_id, local_dir=directory, allow_patterns=download_files, token=access_token, ) constants.HF_HUB_ENABLE_HF_TRANSFER = previous download.HF_HUB_ENABLE_HF_TRANSFER = previous if convert_checkpoint and not tokenizer_only: print("Converting checkpoint files to LitGPT format.") convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name) def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]: from huggingface_hub import repo_info from huggingface_hub.utils import filter_repo_objects with gated_repo_catcher(repo_id, access_token): info = repo_info(repo_id, token=access_token) filenames = [f.rfilename for f in info.siblings] bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])) safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"])) return bins, safetensors 
@contextmanager def gated_repo_catcher(repo_id: str, access_token: Optional[str]): try: yield except OSError as e: err_msg = str(e) if "Repository Not Found" in err_msg: raise ValueError( f"Repository at https://huggingface.co/api/models/{repo_id} not found." " Please make sure you specified the correct `repo_id`." ) from None elif "gated repo" in err_msg: if not access_token: raise ValueError( f"https://huggingface.co/{repo_id} requires authentication, please set the `HF_TOKEN=your_token`" " environment variable or pass `--access_token=your_token`. You can find your token by visiting" " https://huggingface.co/settings/tokens." ) from None else: raise ValueError( f"https://huggingface.co/{repo_id} requires authentication. The access token provided by `HF_TOKEN=your_token`" " environment variable or `--access_token=your_token` may not have sufficient access rights. Please" f" visit https://huggingface.co/{repo_id} for more information." ) from None raise e from None ================================================ FILE: litgpt/scripts/merge_lora.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """This script merges the LoRA weights with the base model""" from pathlib import Path from pprint import pprint from typing import Any, Dict, Optional, Tuple import lightning as L import torch import yaml from litgpt.lora import GPT, Config, lora_filter, merge_lora_weights from litgpt.utils import check_valid_checkpoint_dir, extend_checkpoint_dir def merge_lora( checkpoint_dir: Path, pretrained_checkpoint_dir: Optional[Path] = None, precision: Optional[str] = None ) -> None: """Merges the LoRA weights with the base model. See ``litgpt finetune lora``. Creates a new ``lit_model.pth`` file by merging the LoRA weights (``lit_model.pth.lora``) with the original checkpoint weights. 
Arguments: checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of ``litgpt finetune lora``. pretrained_checkpoint_dir: Optional path to the checkpoint directory with the weights of the base model corresponding to the LoRA checkpoint. By default, this will automatically be inferred from the metadata in the given `checkpoint_dir` directory. Only set this if the base model's checkpoint directory has moved or was renamed. precision: Optional precision setting to instantiate the model weights in. By default, this will automatically be inferred from the metadata in the given ``checkpoint_dir`` directory. """ checkpoint_dir = extend_checkpoint_dir(checkpoint_dir) if pretrained_checkpoint_dir is not None: pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir) pprint(locals()) check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth.lora") if pretrained_checkpoint_dir is not None: check_valid_checkpoint_dir(pretrained_checkpoint_dir) if (checkpoint_dir / "lit_model.pth").is_file(): print("LoRA weights have already been merged in this checkpoint.") return lora_params, meta_pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir) precision = precision if precision is not None else lora_precision if pretrained_checkpoint_dir is None: pretrained_checkpoint_dir = meta_pretrained_checkpoint_dir pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir) fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu") config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params) with fabric.init_module(), torch.device("meta"): model = GPT(config) # we don't care about these to perform merging model.cos = None model.sin = None lora_path = checkpoint_dir / "lit_model.pth.lora" pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True) lora_checkpoint = torch.load(str(lora_path), mmap=True) lora_checkpoint = 
lora_checkpoint.get("model", lora_checkpoint) # Merge LoRA weights into the base model pretrained_checkpoint.update(lora_checkpoint) model.load_state_dict(pretrained_checkpoint, assign=True) # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype lora_dtype = next(iter(lora_checkpoint.values())).dtype model.to(dtype=lora_dtype, device="cpu") merge_lora_weights(model) # Remove LoRA parameters and the LoRA linear substring state_dict = {k.replace("linear.", ""): v for k, v in model.state_dict().items() if not lora_filter(k, v)} save_path = checkpoint_dir / "lit_model.pth" torch.save(state_dict, save_path) fabric.print(f"Saved merged weights to {str(checkpoint_dir / 'lit_model.pth')!r}") def load_lora_metadata(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path, Optional[str]]: hparams_file = checkpoint_dir / "hyperparameters.yaml" if not hparams_file.is_file(): raise FileNotFoundError( f"The path {str(hparams_file)!r} is not a valid checkpoint directory. It is missing a" f" `hyperparameters.yaml` file. Please point to the checkpoint directory that was produced by" f" the `litgpt/finetune/lora.py` script." ) with open(hparams_file, encoding="utf-8") as file: hparams = yaml.safe_load(file) lora_params = {k: v for k, v in hparams.items() if k.startswith("lora_")} pretrained_checkpoint_dir = Path(hparams["checkpoint_dir"]) precision = hparams.get("precision") return lora_params, pretrained_checkpoint_dir, precision ================================================ FILE: litgpt/tokenizer.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import json from pathlib import Path from typing import Iterable, Iterator, Optional, Union import torch from litgpt.utils import fix_and_load_json class Tokenizer: def __init__(self, checkpoint_dir: Union[Path, str]) -> None: checkpoint_dir = Path(checkpoint_dir) if not checkpoint_dir.exists(): raise NotADirectoryError(f"The checkpoint directory does not exist: {str(checkpoint_dir)}") self.model_name = checkpoint_dir.stem self.use_bos = self.check_if_bos_token_used(checkpoint_dir) self.bos_id = None self.eos_id = None # some checkpoints have both files, `.json` takes precedence if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): from tokenizers import Tokenizer as HFTokenizer self.processor = HFTokenizer.from_file(str(vocabulary_path)) self.backend = "huggingface" if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file(): with open(special_tokens_path, encoding="utf-8") as fp: config = json.load(fp) bos_token = config.get("bos_token") eos_token = config.get("eos_token") if bos_token is not None and isinstance(bos_token, dict): bos_token = bos_token.get("content") if eos_token is not None and isinstance(eos_token, dict): eos_token = eos_token.get("content") self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file(): try: with open(special_tokens_path, encoding="utf-8") as fp: config = json.load(fp) except json.JSONDecodeError: # Some files like the Llama 3.2 one have bugs with open(special_tokens_path, encoding="utf-8") as fp: json_string = fp.read() config = fix_and_load_json(json_string) if self.bos_id is None: self.bos_id = config.get("bos_token_id") if self.eos_id is None: self.eos_id = config.get("eos_token_id") elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): from sentencepiece import SentencePieceProcessor self.processor = 
SentencePieceProcessor(model_file=str(vocabulary_path)) self.backend = "sentencepiece" self.bos_id = self.processor.bos_id() self.eos_id = self.processor.eos_id() else: raise NotImplementedError # NOTE: A temporary fix until it's resolved on Tokenizers side. # LlaMA tokenizer strips leading spaces if to decode a single token at a time. # https://github.com/huggingface/transformers/issues/31643 self.apply_decoding_fix = None if (config_path := checkpoint_dir / "tokenizer_config.json").is_file(): with open(config_path, encoding="utf-8") as fp: self.apply_decoding_fix = "LlamaTokenizer" in json.load(fp)["tokenizer_class"] @property def vocab_size(self) -> int: if self.backend == "huggingface": return self.processor.get_vocab_size(with_added_tokens=False) if self.backend == "sentencepiece": return self.processor.vocab_size() raise RuntimeError def token_to_id(self, token: str) -> int: if self.backend == "huggingface": id_ = self.processor.token_to_id(token) elif self.backend == "sentencepiece": id_ = self.processor.piece_to_id(token) else: raise RuntimeError if id_ is None: raise ValueError(f"token {token!r} not found in the collection.") return id_ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file(): return False with open(tokenizer_config_path, encoding="utf-8") as fp: config = json.load(fp) # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only # `PreTrainedTokenizerFast` if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): return True if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"): return True if "add_bos_token" in config: return config["add_bos_token"] # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True. 
# ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2 return config.get("tokenizer_class") == "LlamaTokenizer" def encode( self, string: str, device: Optional[torch.device] = None, bos: Optional[bool] = None, eos: bool = False, max_length: int = -1, ) -> torch.Tensor: if self.backend == "huggingface": tokens = self.processor.encode(string).ids elif self.backend == "sentencepiece": tokens = self.processor.encode(string) else: raise RuntimeError(f"`{self.backend}` is not supported.") if tokens is None: raise ValueError("`self.processor` returned tokens of None value.") if bos or (bos is None and self.use_bos): if self.bos_id is None: raise NotImplementedError("This tokenizer does not have a defined bos token.") if not tokens or tokens[0] != self.bos_id: tokens = [self.bos_id] + tokens # if the processor misbehaves and adds `bos` token no matter what elif tokens and tokens[0] == self.bos_id: tokens = tokens[1:] if eos and (not tokens or tokens[-1] != self.eos_id): tokens = tokens + [self.eos_id] # if the processor misbehaves and adds `eos` token no matter what elif tokens and tokens[-1] == self.eos_id: tokens = tokens[:-1] if max_length > 0: tokens = tokens[:max_length] return torch.tensor(tokens, dtype=torch.int, device=device) def decode(self, tensor: torch.Tensor) -> str: tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() if len(tokens) == 1 and self.apply_decoding_fix: dummy_token_id = 33 # \x1e dummy_token = self.processor.decode([dummy_token_id]) if dummy_token != "\x1e": dummy_token_id = 165 # \x1e is different in salamandra tokenizers dummy_token = self.processor.decode([dummy_token_id]) return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :] return self.processor.decode(tokens) def decode_stream( self, token_stream: Iterable[torch.Tensor], device: Optional[torch.device] = None ) -> Iterator[str]: if self.backend == "huggingface": try: for token in token_stream: yield self.decode(token) 
except KeyboardInterrupt: return elif self.backend == "sentencepiece": # TODO: Is there a way to not have to do this? # This may actually affect our tokens per second. # sentencepiece does not support decoding token-by-token because it adds spaces based on the surrounding tokens # meaning that we need to decode everything each time so_far = torch.tensor([], dtype=torch.long, device=device) decoded_so_far = "" try: for token in token_stream: so_far = so_far.to(device=token.device) so_far = torch.cat((so_far, token.view(-1))) decoded_new = self.decode(so_far) yield decoded_new[len(decoded_so_far) :] decoded_so_far = decoded_new except KeyboardInterrupt: return else: raise NotImplementedError(self.backend) ================================================ FILE: litgpt/types.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. """Type aliases used across LitGPT modules.""" from typing import Literal # Logger-related types LoggerChoice = Literal["csv", "tensorboard", "wandb", "mlflow", "litlogger"] """Valid logger choices for experiment tracking. Available options: - "csv": Local CSV file logging (default for most scripts) - "tensorboard": TensorBoard visualization (default for pretrain) - "wandb": Weights & Biases cloud tracking - "mlflow": MLflow experiment tracking - "litlogger": Lightning.ai native tracking """ ================================================ FILE: litgpt/utils.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
"""Utility functions for training and inference.""" import inspect import json import math import os import pickle import random import re import shutil import subprocess import sys import warnings from dataclasses import asdict, is_dataclass from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union import lightning as L import psutil import torch import torch.nn as nn import torch.utils._device import yaml from lightning.fabric.loggers import CSVLogger, TensorBoardLogger from lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy from lightning.fabric.utilities.load import _lazy_load as lazy_load from lightning.pytorch.cli import instantiate_class from lightning.pytorch.loggers import MLFlowLogger, WandbLogger from packaging import version from torch.serialization import normalize_storage_type from typing_extensions import Self from litgpt.constants import ( _LITLOGGER_AVAILABLE, _SUPPORTED_LOGGERS, _THUNDER_AVAILABLE, ) from litgpt.types import LoggerChoice if TYPE_CHECKING: from litgpt import GPT, Config def init_out_dir(out_dir: Path) -> Path: if not isinstance(out_dir, Path): out_dir = Path(out_dir) if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ: return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir return out_dir def find_resume_path(resume: Union[bool, Literal["auto"], Path], out_dir: Path) -> Optional[Path]: if not resume or isinstance(resume, Path): return resume resume_path = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1])), default=None) if resume == "auto": return resume_path if resume is True and resume_path is None: raise FileNotFoundError( f"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`." 
) return resume_path def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: total = 0 for p in module.parameters(): if requires_grad is None or p.requires_grad == requires_grad: if hasattr(p, "quant_state"): # bitsandbytes 4bit layer support total += math.prod(p.quant_state.shape) else: total += p.numel() return total def reset_parameters(module: nn.Module) -> None: """Calls `reset_parameters` on the module and all its submodules.""" for mod in module.modules(): if callable(getattr(mod, "reset_parameters", None)): mod.reset_parameters() def check_valid_checkpoint_dir( checkpoint_dir: Path, model_filename: str = "lit_model.pth", verbose: bool = True, raise_error: bool = False, ignore_tokenizer_files: bool = False, ) -> None: files = { model_filename: (checkpoint_dir / model_filename).is_file(), "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(), } if not ignore_tokenizer_files: files.update( { "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (checkpoint_dir / "tokenizer.model").is_file(), "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), } ) if checkpoint_dir.is_dir(): if all(files.values()): # we're good return problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" else: problem = " is not a checkpoint directory" # list locally available checkpoints available = list(Path("checkpoints").glob("*/*")) if available: options = "\n".join([""] + [repr(str(p.resolve())) for p in available]) extra = f"\nYou have downloaded locally:{options}\n" else: extra = "" if verbose: error_message = ( f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 
"\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n" f"{extra}\nSee all download options by running:\n litgpt download" ) print(error_message, file=sys.stderr) if raise_error: raise FileNotFoundError(f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}.") else: raise SystemExit(1) class SavingProxyForStorage: def __init__(self, obj, saver, protocol_version=5): self.protocol_version = protocol_version self.saver = saver if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): raise TypeError(f"expected storage, not {type(obj)}") # this logic is taken from PyTorch 2.0+ torch/serialization.py if isinstance(obj, torch.storage.TypedStorage): # PT upstream wants to deprecate this eventually... storage = obj._untyped_storage storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) storage_numel = obj._size() else: storage = obj storage_type = normalize_storage_type(type(obj)) storage_numel = storage.nbytes() storage_key = saver._write_storage_and_return_key(storage) location = torch.serialization.location_tag(storage) self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) def __reduce_ex__(self, protocol_version): assert False, "this should be handled with out of band" class SavingProxyForTensor: def __init__(self, tensor, saver, protocol_version=5): self.protocol_version = protocol_version self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version) if reduce_args[0] == torch._utils._rebuild_tensor_v2: # for Tensors with Python attributes (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), ( "Please check for updates" ) storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args) else: (storage, *other_reduce_args) = 
reduce_args assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), ( "Please check for updates" ) storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) self.reduce_args = (storage_proxy, *other_reduce_args) def __reduce_ex__(self, protocol_version): if protocol_version != self.protocol_version: raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") return self.reduce_ret_fn, self.reduce_args class IncrementalPyTorchPickler(pickle.Pickler): def __init__(self, saver, *args, **kwargs): super().__init__(*args, **kwargs) self.storage_dtypes = {} self.saver = saver self.id_map = {} # this logic is taken from PyTorch 2.0+ torch/serialization.py def persistent_id(self, obj): # FIXME: the docs say that persistent_id should only return a string # but torch store returns tuples. This works only in the binary protocol # see # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 if isinstance(obj, SavingProxyForStorage): return obj.storage_info if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted storage = obj._untyped_storage storage_dtype = obj.dtype storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) storage_numel = obj._size() else: storage = obj storage_dtype = torch.uint8 storage_type = normalize_storage_type(type(obj)) storage_numel = storage.nbytes() # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. 
If storage is # not allocated, don't perform this check if storage.data_ptr() != 0: if storage.data_ptr() in self.storage_dtypes: if storage_dtype != self.storage_dtypes[storage.data_ptr()]: raise RuntimeError( "Cannot save multiple tensors or storages that view the same data as different types" ) else: self.storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = self.id_map.get(storage._cdata) if storage_key is None: storage_key = self.saver._write_storage_and_return_key(storage) self.id_map[storage._cdata] = storage_key location = torch.serialization.location_tag(storage) return ("storage", storage_type, storage_key, location, storage_numel) return None class incremental_save: def __init__(self, name): self.name = name self.zipfile = torch._C.PyTorchFileWriter(str(name)) self.has_saved = False self.next_key = 0 self.protocol_version = 2 def __enter__(self): return self def store_early(self, tensor): if isinstance(tensor, torch.Tensor): return SavingProxyForTensor(tensor, self, protocol_version=self.protocol_version) raise TypeError(f"can only store tensors early, not {type(tensor)}") def save(self, obj): if self.has_saved: raise RuntimeError("have already saved") # Write the pickle data for `obj` data_buf = BytesIO() pickler = IncrementalPyTorchPickler(self, data_buf, protocol=self.protocol_version) pickler.dump(obj) data_value = data_buf.getvalue() self.zipfile.write_record("data.pkl", data_value, len(data_value)) self.has_saved = True def _write_storage_and_return_key(self, storage): if self.has_saved: raise RuntimeError("have already saved") key = self.next_key self.next_key += 1 name = f"data/{key}" if storage.device.type != "cpu": storage = storage.cpu() num_bytes = storage.nbytes() current_version = version.parse(torch.__version__) threshold_version = version.parse("2.2.2") if current_version <= threshold_version: self.zipfile.write_record(name, storage.data_ptr(), num_bytes) else: self.zipfile.write_record(name, storage, num_bytes) return key def 
__exit__(self, type, value, traceback): self.zipfile.write_end_of_file() T = TypeVar("T") def chunked_cross_entropy( logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128, ignore_index: int = -100, ) -> torch.Tensor: # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate # the memory usage in fine-tuning settings with low number of parameters. # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing # the memory spike's magnitude # lm_head was chunked (we are fine-tuning) if isinstance(logits, list): # don't want to chunk cross entropy if chunk_size == 0: logits = torch.cat(logits, dim=1) logits = logits.reshape(-1, logits.size(-1)) targets = targets.reshape(-1) return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index) # chunk cross entropy logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] loss_chunks = [ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none") for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) ] non_masked_elems = (targets != ignore_index).sum() # See [non_masked_elems div note] return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems)) # no chunking at all logits = logits.reshape(-1, logits.size(-1)) targets = targets.reshape(-1) if chunk_size == 0: return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index) # lm_head wasn't chunked, chunk cross entropy logit_chunks = logits.split(chunk_size) target_chunks = targets.split(chunk_size) loss_chunks = [ torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none") for logit_chunk, target_chunk in zip(logit_chunks, 
target_chunks) ] non_masked_elems = (targets != ignore_index).sum() # [non_masked_elems div note]: # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that # results in a python int which is then passed back to torch division. By using the # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize. return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems)) def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: for checkpoint_name, attribute_name in mapping.items(): full_checkpoint_name = prefix + checkpoint_name if full_checkpoint_name in state_dict: full_attribute_name = prefix + attribute_name state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) return state_dict def get_default_supported_precision(training: bool) -> str: """ Return the default precision that is supported by the hardware: either `bf16` or `16`. Args: training: If True, returns '-mixed' version of the precision; if False, returns '-true' version. Returns: The default precision that is suitable for the task and is supported by the hardware. 
""" import torch if torch.cuda.is_available(): if torch.cuda.is_bf16_supported(): return "bf16-mixed" if training else "bf16-true" else: return "16-mixed" if training else "16-true" return "bf16-mixed" if training else "bf16-true" def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None: if isinstance(fabric.strategy, FSDPStrategy): fabric.load_raw(checkpoint_path, model, strict=strict) elif isinstance(fabric.strategy, ModelParallelStrategy): state_dict = torch.load(checkpoint_path, mmap=True) load_from_full_model_state_dict( model=model, full_sd=state_dict, device=fabric.device, strict=strict, cpu_offload=True, ) else: state_dict = lazy_load(checkpoint_path) state_dict = state_dict.get("model", state_dict) model.load_state_dict(state_dict, strict=strict) def load_checkpoint_update( fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True ) -> None: if isinstance(fabric.strategy, FSDPStrategy): fabric.load_raw(checkpoint_path, model, strict=strict) else: state_dict = lazy_load(checkpoint_path) state_dict = state_dict.get("model", state_dict) adapter_cp = lazy_load(adapter_path) state_dict.update(adapter_cp) model.load_state_dict(state_dict, strict=strict) def load_from_full_model_state_dict( model: torch.nn.Module, full_sd: Dict[str, Any], device: torch.device, strict: bool = False, cpu_offload: bool = False, ): from torch.distributed._tensor import distribute_tensor meta_sharded_sd = model.state_dict() sharded_sd = {} print(meta_sharded_sd.keys()) for param_name, full_tensor in full_sd.items(): if "norm" not in param_name and "wte" not in param_name and "ln_f" not in param_name: param_name = param_name.replace(".weight", ".linear.weight") param_name = param_name.replace(".bias", ".linear.bias") else: param_name = param_name print(param_name) sharded_meta_param = meta_sharded_sd.get(param_name) full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) sharded_tensor 
= distribute_tensor( full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements, ) if cpu_offload: sharded_tensor = sharded_tensor.cpu() sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) # choose `assign=True` since we cannot call `copy_` on meta tensor return model.load_state_dict(sharded_sd, strict=strict, assign=True) def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int: flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation # this assumes that all samples have a fixed length equal to the block size # which is most likely false during finetuning flops_per_seq = flops_per_token * max_seq_length attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2)) return flops_per_seq + attn_flops_per_seq def estimate_flops(model: "GPT", training: bool) -> int: """Measures estimated FLOPs for MFU. Refs: * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 """ # using all parameters for this is a naive over estimation because not all model parameters actually contribute to # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage # (~10%) compared to the measured FLOPs, making those lower but more realistic. # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. 
n_trainable_params = num_parameters(model, requires_grad=True) trainable_flops = flops_per_param( model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params ) # forward + backward + gradients (assumes no gradient accumulation) ops_per_step = 3 if training else 1 n_frozen_params = num_parameters(model, requires_grad=False) frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params) # forward + backward frozen_ops_per_step = 2 if training else 1 return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops class CycleIterator: """An iterator that cycles through an iterable indefinitely. Example: >>> iterator = CycleIterator([1, 2, 3]) >>> [next(iterator) for _ in range(5)] [1, 2, 3, 1, 2] Note: Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable. """ def __init__(self, iterable: Iterable) -> None: self.iterable = iterable self.epoch = 0 self._iterator = None def __next__(self) -> Any: if self._iterator is None: self._iterator = iter(self.iterable) try: return next(self._iterator) except StopIteration: self._iterator = iter(self.iterable) self.epoch += 1 return next(self._iterator) def __iter__(self) -> Self: return self def copy_config_files(source_dir: Path, out_dir: Path) -> None: """Copies the specified configuration and tokenizer files into the output directory.""" config_files = ["config.json", "generation_config.json", "model_config.yaml"] tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"] for file_name in config_files + tokenizer_files: src_path = source_dir / file_name if src_path.exists(): shutil.copy(src_path, out_dir) def CLI(*args: Any, **kwargs: Any) -> Any: from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options set_docstring_parse_options(attribute_docstrings=True) set_config_read_mode(urls_enabled=True) return CLI(*args, **kwargs) def capture_hparams() -> Dict[str, Any]: 
"""Captures the local variables ('hyperparameters') from where this function gets called.""" caller_frame = inspect.currentframe().f_back locals_of_caller = caller_frame.f_locals hparams = {} for name, value in locals_of_caller.items(): if value is None or isinstance(value, (int, float, str, bool, Path)): hparams[name] = value elif is_dataclass(value): hparams[name] = asdict(value) else: hparams[name] = str(value) return hparams def save_config(config: "Config", checkpoint_dir: Path) -> None: config_dict = asdict(config) with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp: yaml.dump(config_dict, fp) def parse_devices(devices: Union[str, int]) -> int: if devices in (-1, "auto"): return torch.cuda.device_count() or 1 if isinstance(devices, int) and devices > 0: return devices raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}") def choose_logger( logger_name: LoggerChoice, out_dir: Path, name: str, log_interval: int = 1, log_args: Optional[Dict] = None, resume: Optional[bool] = None, **kwargs: Any, ): if logger_name == "csv": return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs) if logger_name == "tensorboard": return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs) if logger_name == "wandb": project = log_args.pop("project", name) run = log_args.pop("run", os.environ.get("WANDB_RUN_NAME")) group = log_args.pop("group", os.environ.get("WANDB_RUN_GROUP")) return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs) if logger_name == "mlflow": return MLFlowLogger(experiment_name=name, **kwargs) if logger_name == "litlogger": if not _LITLOGGER_AVAILABLE: raise ModuleNotFoundError(_LITLOGGER_AVAILABLE) from lightning.pytorch.loggers import LitLogger # Extract litlogger-specific args teamspace = log_args.pop("teamspace", None) if log_args else None metadata = log_args.pop("metadata", None) if log_args else None 
log_model = log_args.pop("log_model", False) if log_args else False save_logs = log_args.pop("save_logs", True) if log_args else True checkpoint_name = log_args.pop("checkpoint_name", None) if log_args else None return LitLogger( root_dir=(out_dir / "logs"), name=name, teamspace=teamspace, metadata=metadata, log_model=log_model, save_logs=save_logs, checkpoint_name=checkpoint_name, **kwargs, ) raise ValueError( f"`--logger_name={logger_name}` is not a valid option. Choose from {', '.join(_SUPPORTED_LOGGERS)}." ) def get_argument_names(cls): sig = inspect.signature(cls.__init__) return { name for name, param in sig.parameters.items() if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY] } def instantiate_bnb_optimizer(optimizer, model_parameters): if (isinstance(optimizer, str) and "AdamW" not in optimizer) or ( isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "") ): raise ValueError("The chosen quantization format only supports the AdamW optimizer.") import bitsandbytes as bnb if isinstance(optimizer, str): optimizer = bnb.optim.PagedAdamW(model_parameters) else: optim_args = get_argument_names(bnb.optim.PagedAdamW) allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()} optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs) return optimizer def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs): # Special care taken where some optimizers do not have some parameters referenced in some of the code, for example "fused" in the pretrain.py script: # bnb.optim.AdamW8bit # grokadamw.GrokAdamW # torch.optim.RMSprop if isinstance(optimizer, str): if "." 
in optimizer: class_module, class_name = optimizer.rsplit(".", 1) else: class_module, class_name = "torch.optim", optimizer module = __import__(class_module, fromlist=[class_name]) optimizer_cls = getattr(module, class_name) valid_params = set(inspect.signature(optimizer_cls).parameters) kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params} optimizer = optimizer_cls(model_parameters, **kwargs) elif isinstance(optimizer, dict): optimizer = dict(optimizer) class_module, class_name = optimizer["class_path"].rsplit(".", 1) module = __import__(class_module, fromlist=[class_name]) optimizer_cls = getattr(module, class_name) valid_params = set(inspect.signature(optimizer_cls).parameters) kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params} optimizer["init_args"].update(kwargs) optimizer = instantiate_class(model_parameters, optimizer) else: raise ValueError(f'Unrecognized "optimizer" value: {optimizer}') return optimizer def extend_checkpoint_dir(checkpoint_dir: Path) -> Path: new_checkpoint_dir = "checkpoints" / checkpoint_dir should_return_new_dir = ( not checkpoint_dir.is_dir() and checkpoint_dir.parts[0] != "checkpoints" and not checkpoint_dir.is_absolute() and new_checkpoint_dir.exists() ) return new_checkpoint_dir if should_return_new_dir else checkpoint_dir def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660): """ Checks the file size and raises a warning if it exceeds the size_limit. The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660 """ size = 0.0 if os.path.exists(checkpoint_path): size = os.path.getsize(checkpoint_path) if size > size_limit and str(device) == "cpu": warnings.warn( f"The file size of {checkpoint_path} is over {size_limit / 1024 / 1024 / 1024:.1f} GB. Using a model " "with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU." 
) return size def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False): from litgpt.scripts.download import download_from_hub # moved here due to circular import issue checkpoint_dir = extend_checkpoint_dir(Path(model_name)) try: check_valid_checkpoint_dir( checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files ) except FileNotFoundError as e: if access_token is None: access_token = os.getenv("HF_TOKEN") if checkpoint_dir.parts[0] != "checkpoints" and not checkpoint_dir.is_absolute(): download_from_hub(repo_id=str(model_name), access_token=access_token) checkpoint_dir = Path("checkpoints") / checkpoint_dir else: raise e return checkpoint_dir def check_nvlink_connectivity(fabric=None): """Checks GPU connectivity for both NVIDIA and AMD GPUs. This function delegates to vendor-specific implementations based on the detected GPU vendor. """ if fabric is not None: custom_print = fabric.print else: custom_print = print if os.getenv("RANK", "0") == "0": try: if torch.cuda.is_available(): device_properties = torch.cuda.get_device_properties(0) gpu_name = device_properties.name.lower() if "nvidia" in gpu_name: _check_nvidia_connectivity(custom_print) elif "advanced micro devices" in gpu_name or "amd" in gpu_name: _check_amd_connectivity(custom_print) else: custom_print(f"Unrecognized GPU vendor: {device_properties.name}") else: custom_print("No GPUs available") except Exception as e: custom_print(f"An error occurred while checking GPU connectivity: {e}") def _check_nvidia_connectivity(custom_print): """Checks NVLink connectivity on NVIDIA GPUs.""" result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True) if result.returncode != 0: custom_print("Failed to run nvidia-smi") return lines = result.stdout.strip().split("\n") start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None) if start_index is None: custom_print("Failed to parse nvidia-smi output") 
return headers_line = lines[start_index] headers = headers_line.split() gpu_regex = re.compile(r"^GPU\d+$") gpu_count = len([header for header in headers if gpu_regex.match(header)]) all_nvlink = True for line in lines[start_index + 1 : start_index + 1 + gpu_count]: columns = line.split() connections = columns[1 : 1 + gpu_count] if not all("NV" in conn for conn in connections if conn != "X"): all_nvlink = False break if all_nvlink: custom_print("All GPUs are fully connected via NVLink.") else: custom_print( "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. " "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance." ) def _check_amd_connectivity(custom_print): """Checks XGMI connectivity on AMD GPUs.""" result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True) if result.returncode != 0: custom_print("Failed to run rocm-smi") return lines = result.stdout.strip().split("\n") gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None) if gpu_header_index is None or gpu_header_index == 0: custom_print("Failed to parse rocm-smi output (no GPU headers found)") return header_line = lines[gpu_header_index - 1] headers = header_line.strip().split() gpu_regex = re.compile(r"^GPU\d+$") gpu_count = len([header for header in headers if gpu_regex.match(header)]) gpu_lines = [] for line in lines[gpu_header_index : gpu_header_index + gpu_count]: if re.match(r"^\s*GPU\d+", line): gpu_lines.append(line.strip()) if len(gpu_lines) != gpu_count: custom_print("Mismatch in GPU count when parsing rocm-smi output") return all_xgmi = True for line in gpu_lines: columns = line.split() connections = columns[1 : 1 + gpu_count] for conn in connections: if conn not in ("XGMI", "0"): all_xgmi = False break if not all_xgmi: break if all_xgmi: custom_print("All GPUs are fully connected via XGMI.") else: 
custom_print( "Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. " "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance." ) def fix_and_load_json(s): # Remove trailing commas before } or ] s = re.sub(r",(\s*[}\]])", r"\1", s) # Insert missing commas between properties # Match positions where a value is followed by a newline and then a quote without a comma pattern = r'(?<=[}\]0-9truefalsenull"])\s*(\n\s*)"' replacement = r',\1"' s = re.sub(pattern, replacement, s) # Now try to parse the JSON try: return json.loads(s) except json.JSONDecodeError as e: raise ValueError(f"Failed to parse JSON after fixing: {e}") def create_finetuning_performance_report(training_time, token_counts, device_type): tok_sec = token_counts["raw_tokens_plus_prompt_template_and_padding"] / training_time output = f""" | ------------------------------------------------------ | Token Counts | - Input Tokens : {token_counts["raw_tokens"]:>5} | - Tokens w/ Prompt : {token_counts["raw_tokens_plus_prompt_template"]:>5} | - Total Tokens (w/ Padding) : {token_counts["raw_tokens_plus_prompt_template_and_padding"]:>5} | ----------------------------------------------------- | Performance | - Training Time : {training_time:.2f} s | - Tok/sec : {tok_sec:.2f} tok/s | ----------------------------------------------------- """ if device_type == "cuda": memory_used = torch.cuda.max_memory_allocated() / 1e9 output += "| Memory Usage \n" output += f"| - Memory Used : {memory_used:.02f} GB \n" output += "-------------------------------------------------------\n" return output def select_sft_generate_example(eval, data): if eval.evaluate_example == "first": if len(data.test_dataset.data): instruction = data.test_dataset.data[0]["instruction"] else: instruction = data.train_dataset.data[0]["instruction"] elif eval.evaluate_example == "random": if len(data.test_dataset.data): random_idx = 
random.randint(0, len(data.test_dataset.data) - 1) instruction = data.test_dataset.data[random_idx]["instruction"] else: random_idx = random.randint(0, len(data.train_dataset.data) - 1) instruction = data.train_dataset.data[random_idx]["instruction"] elif isinstance(eval.evaluate_example, int): index = eval.evaluate_example if len(data.test_dataset.data) > index: instruction = data.test_dataset.data[index]["instruction"] elif len(data.train_dataset.data) > index: instruction = data.train_dataset.data[index]["instruction"] else: raise IndexError(f"Index {index} is out of range for both test and training datasets.") else: raise ValueError(f"Unknown evaluation example type: {eval.evaluate_example}") return instruction def _RunIf(thunder: bool = False, **kwargs): import pytest from lightning.fabric.utilities.testing import _runif_reasons reasons, marker_kwargs = _runif_reasons(**kwargs) if thunder and not _THUNDER_AVAILABLE: # if we require Thunder, but it's not available, we should skip reasons.append("Thunder") return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs) def kill_process_tree(pid: int): """ Kill a process and all its child processes given the parent PID. 
""" try: parent = psutil.Process(pid) children = parent.children(recursive=True) for child in children: child.kill() parent.kill() except psutil.NoSuchProcess: pass # Process already exited ================================================ FILE: pyproject.toml ================================================ [build-system] build-backend = "setuptools.build_meta" requires = [ "setuptools>=68.2.2", "wheel>=0.41.2", ] [project] name = "litgpt" version = "0.5.12" description = "Hackable implementation of state-of-the-art open-source LLMs" readme = "README.md" license = { file = "LICENSE" } authors = [ { name = "Lightning AI", email = "contact@lightning.ai" }, ] requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", ] dependencies = [ # download models: "huggingface-hub>=0.30,<1.4", "jsonargparse[signatures]>=4.37,<=4.41; python_version>='3.10'", # required to work with Python >=3.10 "lightning>=2.6.1", "psutil==7.1.3", "safetensors>=0.4.3", # tokenization in most models: "tokenizers>=0.21", "torch>=2.7", # convert_hf_checkpoint "tqdm>4.66", ] optional-dependencies.compiler = [ # compilaton: "lightning-thunder>=0.2.dev20250119; python_version>='3.10' and sys_platform=='linux'", ] optional-dependencies.extra = [ "bitsandbytes>=0.42,<0.43; sys_platform=='darwin'", # quantization: "bitsandbytes>=0.45.2,<0.50; sys_platform=='linux' or sys_platform=='win32'", # litgpt.evaluate: "datasets>=2.18,<4", # download: "huggingface-hub[hf-transfer]>=0.21", "litdata==0.2.59", # litgpt logging: "litlogger>=0.1.7", # litgpt.deploy: "litserve>0.2", # lm-eval: pinned <0.4.9.1 due to trust_remote_code issues with datasets like logiqa. 
# See: https://github.com/EleutherAI/lm-evaluation-harness/issues/3171 "lm-eval>=0.4.2,<0.4.9.1", # litgpt.data.prepare_starcoder.py: "pandas>=1.9", "pyarrow>=15.0.2", # litgpt.data: "requests>=2.31", # llama-based models: "sentencepiece>=0.2", # litgpt.pretrain: "tensorboard>=2.14", "torchmetrics>=1.3.1", "transformers>=4.51.3,<4.57", # litdata, only on non-Windows: "uvloop>=0.2; sys_platform!='win32'", # litgpt.data.prepare_slimpajama.py: "zstandard>=0.22", ] optional-dependencies.test = [ "einops>=0.7", "protobuf>=4.23.4", "pytest>=8.1.1", "pytest-benchmark>=5.1", "pytest-dependency>=0.6", "pytest-rerunfailures>=14", "pytest-timeout>=2.3.1", ] urls.documentation = "https://github.com/lightning-AI/litgpt/tree/main/tutorials" urls.homepage = "https://github.com/lightning-AI/litgpt" scripts.litgpt = "litgpt.__main__:main" [tool.setuptools.packages.find] include = [ "litgpt", "litgpt.*", ] exclude = [ ] [tool.setuptools.package-data] litgpt = [ "LICENSE", "README.md", ] [tool.ruff] target-version = "py38" line-length = 120 exclude = [ "build", "dist", "docs", ] lint.select = [ "E", "F", # see: https://pypi.org/project/pyflakes "I", # implementation for isort "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up "W", # see: https://pypi.org/project/pycodestyle ] #extend-select = [ # "C4", # see: https://pypi.org/project/flake8-comprehensions # "PT", # see: https://pypi.org/project/flake8-pytest-style # "RET", # see: https://pypi.org/project/flake8-return # "SIM", # see: https://pypi.org/project/flake8-simplify #] lint.ignore = [ "E501", # Line too long "E731", # Do not assign a lambda expression, use a def "E741", # todo: Ambiguous variable name "F841", # todo: Local variable is assigned to but never used ] # Use Google-style docstrings.
lint.pydocstyle.convention = "google" [tool.codespell] #skip = '*.py' quiet-level = 3 ignore-words-list = """ tral, \ Rockerfeller """ [tool.pytest.ini_options] addopts = [ "--strict-markers", #"--doctest-modules", "--color=yes", "--disable-pytest-warnings", ] ================================================ FILE: tests/conftest.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import os
import shutil
import sys
from pathlib import Path
from typing import List, Optional

import pytest
import torch

# support running without installing as a package, adding extensions to the Python path
# (`wd` is the repository root: two levels above this conftest file)
wd = Path(__file__).parent.parent.resolve()
if wd.is_dir():
    sys.path.append(str(wd))
else:
    import warnings

    warnings.warn(f"Could not find extensions directory at {wd}")


@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
    """Create ``<tmp_path>/checkpoints/tmp`` holding empty placeholder checkpoint files.

    Also changes the working directory to ``tmp_path``, so the relative ``checkpoints/tmp`` path resolves.
    """
    os.chdir(tmp_path)
    checkpoint_dir = tmp_path / "checkpoints" / "tmp"
    checkpoint_dir.mkdir(parents=True)
    (checkpoint_dir / "lit_model.pth").touch()
    (checkpoint_dir / "model_config.yaml").touch()
    (checkpoint_dir / "tokenizer.json").touch()
    (checkpoint_dir / "tokenizer_config.json").touch()
    return checkpoint_dir


class TensorLike:
    """Equality helper that compares equal to any ``torch.Tensor`` instance."""

    def __eq__(self, other):
        return isinstance(other, torch.Tensor)


@pytest.fixture()
def tensor_like():
    """Provide a ``TensorLike`` placeholder for equality assertions."""
    return TensorLike()


class FloatLike:
    """Equality helper that compares equal to any ``float`` that is not an ``int``."""

    def __eq__(self, other):
        return not isinstance(other, int) and isinstance(other, float)


@pytest.fixture()
def float_like():
    """Provide a ``FloatLike`` placeholder for equality assertions."""
    return FloatLike()


@pytest.fixture(autouse=True)
def restore_default_dtype():
    """Reset the default dtype to ``torch.float32`` before every test (autouse, setup only)."""
    # just in case
    torch.set_default_dtype(torch.float32)


@pytest.fixture(autouse=True)
def destroy_process_group():
    """After every test, destroy the default ``torch.distributed`` process group if one is initialized."""
    yield

    import torch.distributed

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


@pytest.fixture
def turn_off_tf32_and_set_seed(monkeypatch):
    """Set ``NVIDIA_TF32_OVERRIDE=0`` and seed torch with 42 for the test; reseed non-deterministically afterwards."""
    monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
    torch.manual_seed(42)
    yield
    torch.seed()


class MockTokenizer:
    """A dummy tokenizer that encodes each character as its Unicode code point (``ord``), i.e. its ASCII code for ASCII text."""

    bos_id = 0
    eos_id = 1

    def encode(self, text: str, bos: Optional[bool] = None, eos: bool = False, max_length: int = -1) -> torch.Tensor:
        """Encode ``text``, optionally adding BOS/EOS ids, then truncate to ``max_length`` when it is positive."""
        output = []
        if bos:
            output.append(self.bos_id)
        output.extend([ord(c) for c in text])
        if eos:
            output.append(self.eos_id)
        # Truncation happens after BOS/EOS are added, so a short `max_length` can drop the EOS id.
        output = output[:max_length] if max_length > 0 else output
        return torch.tensor(output)

    def decode(self, tokens: torch.Tensor) -> str:
        """Map every id back to a character with ``chr``, including the BOS/EOS ids."""
        return "".join(chr(int(t)) for t in tokens.tolist())


@pytest.fixture()
def mock_tokenizer():
    """Provide a fresh ``MockTokenizer``."""
    return MockTokenizer()


@pytest.fixture()
def alpaca_path(tmp_path):
    """Copy the ``alpaca.json`` test fixture into ``tmp_path`` and return the copy's path."""
    file = Path(__file__).parent / "data" / "_fixtures" / "alpaca.json"
    shutil.copyfile(file, tmp_path / "alpaca.json")
    return tmp_path / "alpaca.json"


@pytest.fixture()
def dolly_path(tmp_path):
    """Copy the ``dolly.json`` test fixture into ``tmp_path`` and return the copy's path."""
    file = Path(__file__).parent / "data" / "_fixtures" / "dolly.json"
    shutil.copyfile(file, tmp_path / "dolly.json")
    return tmp_path / "dolly.json"


@pytest.fixture()
def longform_path(tmp_path):
    """Copy the longform fixtures to ``<tmp_path>/longform/{train,val}.json`` and return the directory."""
    path = tmp_path / "longform"
    path.mkdir()
    for split in ("train", "val"):
        file = Path(__file__).parent / "data" / "_fixtures" / f"longform_{split}.json"
        shutil.copyfile(file, path / f"{split}.json")
    return path


# https://github.com/Lightning-AI/lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/tests/tests_fabric/conftest.py#L140
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
    """Filter collected tests according to environment variables and apply known xfail markers.

    When ``PL_RUN_STANDALONE_TESTS=1`` (or ``RUN_ONLY_CUDA_TESTS=1``), only tests whose ``skipif`` marker carries the
    matching ``standalone`` (or ``min_cuda_gpus``) kwarg are kept. Unconditionally skipped tests are dropped as well.
    """
    initial_size = len(items)
    conditions = []
    filtered, skipped = 0, 0

    options = {"standalone": "PL_RUN_STANDALONE_TESTS", "min_cuda_gpus": "RUN_ONLY_CUDA_TESTS"}
    if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
        # special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
        # by deleting the key, we avoid filtering out the CPU tests
        del options["min_cuda_gpus"]

    for kwarg, env_var in options.items():
        # this will compute the intersection of all tests selected per environment variable
        if os.getenv(env_var, "0") == "1":
            conditions.append(env_var)
            for i, test in reversed(list(enumerate(items))):  # loop in reverse, since we are going to pop items
                already_skipped = any(marker.name == "skip" for marker in test.own_markers)
                if already_skipped:
                    # the test was going to be skipped anyway, filter it out
                    items.pop(i)
                    skipped += 1
                    continue
                has_runif_with_kwarg = any(
                    marker.name == "skipif" and marker.kwargs.get(kwarg) for marker in test.own_markers
                )
                if not has_runif_with_kwarg:
                    # the test is NOT marked with `@_RunIf(<kwarg>=...)`, so it is filtered out
                    items.pop(i)
                    filtered += 1

    if config.option.verbose >= 0 and (filtered or skipped):
        writer = config.get_terminal_writer()
        writer.write(
            f"\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the"
            f" filters {conditions}.\n{skipped} tests are marked as unconditional skips.\nIn total,"
            f" {len(items)} tests will run.\n",
            flush=True,
            bold=True,
            purple=True,  # oh yeah, branded pytest messages
        )

    for test in items:
        if "test_hf_for_nemo" in test.nodeid and "Qwen/Qwen2.5-7B-Instruct" in test.nodeid:
            test.add_marker(
                # Don't use `raises=TypeError` because the actual exception is
                # wrapped inside `torch._dynamo.exc.BackendCompilerFailed`,
                # which prevents pytest from recognizing it as a TypeError.
                pytest.mark.xfail(
                    reason="currently not working, see https://github.com/Lightning-AI/lightning-thunder/issues/2085",
                )
            )
================================================ FILE: tests/convert/__init__.py ================================================ ================================================ FILE: tests/convert/test_hf_checkpoint.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from unittest import mock

import pytest
import torch

from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama, qkv_reassemble


def test_llama2_70b_conversion():
    """Check ``copy_weights_hf_llama`` key mapping and output shapes for Llama-2-70B using meta tensors.

    Layer 0 provides ``q``/``k``/``v`` projections, which must be fused into one ``attn.qkv.weight`` of shape
    (8192 + 1024 + 1024, 8192). The other layers provide only a subset of weights.
    """
    shapes = {
        "model.embed_tokens.weight": (32000, 8192),
        "model.layers.0.input_layernorm.weight": (8192,),
        "model.layers.0.mlp.down_proj.weight": (8192, 28672),
        "model.layers.0.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.0.mlp.up_proj.weight": (28672, 8192),
        "model.layers.0.post_attention_layernorm.weight": (8192,),
        "model.layers.0.self_attn.q_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.k_proj.weight": (1024, 8192),
        "model.layers.0.self_attn.v_proj.weight": (1024, 8192),
        "model.layers.0.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.1.input_layernorm.weight": (8192,),
        "model.layers.1.mlp.down_proj.weight": (8192, 28672),
        "model.layers.1.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.1.mlp.up_proj.weight": (28672, 8192),
        "model.layers.1.post_attention_layernorm.weight": (8192,),
        "model.layers.1.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.2.input_layernorm.weight": (8192,),
        "model.layers.2.mlp.down_proj.weight": (8192, 28672),
        "model.layers.2.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.2.mlp.up_proj.weight": (28672, 8192),
        "model.layers.2.post_attention_layernorm.weight": (8192,),
        "model.layers.2.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.3.input_layernorm.weight": (8192,),
        "model.layers.3.mlp.down_proj.weight": (8192, 28672),
        "model.layers.3.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.3.mlp.up_proj.weight": (28672, 8192),
        "model.layers.3.post_attention_layernorm.weight": (8192,),
        "model.layers.3.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.4.input_layernorm.weight": (8192,),
        "model.layers.4.mlp.down_proj.weight": (8192, 28672),
        "model.layers.4.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.4.mlp.up_proj.weight": (28672, 8192),
        "model.layers.4.post_attention_layernorm.weight": (8192,),
        "model.layers.4.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.5.mlp.gate_proj.weight": (28672, 8192),
        "model.layers.5.self_attn.o_proj.weight": (8192, 8192),
    }
    config = Config.from_name("Llama-2-70b-hf")
    holder = {}
    qkv_weights = {}
    # Meta tensors carry only shapes, so no memory is allocated for the 70B-sized weights.
    with torch.device("meta"):
        weight_map = {k: torch.empty(s) for k, s in shapes.items()}
    copy_weights_hf_llama(config, qkv_weights, holder, weight_map)

    # NOTE: the weight map covers 6 layers (0-5), but only the first layer has `q`, `k` and `v`
    assert len(qkv_weights) == 1
    # there are no loaded qkv weights
    # (presumably the per-layer q/k/v slots are reset to None once fused into `holder` — TODO confirm)
    assert all(v is None for qkv in qkv_weights.values() for v in qkv)
    # the shapes are correct
    holder = {k: tuple(t.shape) for k, t in holder.items()}
    assert holder == {
        "transformer.h.0.attn.qkv.weight": (10240, 8192),
        "transformer.h.0.attn.proj.weight": (8192, 8192),
        "transformer.h.0.mlp.fc_1.weight": (28672, 8192),
        "transformer.h.0.mlp.fc_2.weight": (28672, 8192),
        "transformer.h.0.mlp.proj.weight": (8192, 28672),
        "transformer.h.0.norm_1.weight": (8192,),
        "transformer.h.0.norm_2.weight": (8192,),
        "transformer.h.1.attn.proj.weight": (8192, 8192),
        "transformer.h.1.mlp.fc_1.weight": (28672, 8192),
        "transformer.h.1.mlp.fc_2.weight": (28672, 8192),
        "transformer.h.1.mlp.proj.weight": (8192, 28672),
        "transformer.h.1.norm_1.weight": (8192,),
        "transformer.h.1.norm_2.weight": (8192,),
        "transformer.h.2.attn.proj.weight": (8192, 8192),
        "transformer.h.2.mlp.fc_1.weight": (28672, 8192),
        "transformer.h.2.mlp.fc_2.weight": (28672, 8192),
        "transformer.h.2.mlp.proj.weight": (8192, 28672),
        "transformer.h.2.norm_1.weight": (8192,),
        "transformer.h.2.norm_2.weight": (8192,),
        "transformer.h.3.attn.proj.weight": (8192, 8192),
        "transformer.h.3.mlp.fc_1.weight": (28672, 8192),
        "transformer.h.3.mlp.fc_2.weight": (28672, 8192),
        "transformer.h.3.mlp.proj.weight": (8192, 28672),
        "transformer.h.3.norm_1.weight": (8192,),
        "transformer.h.3.norm_2.weight": (8192,),
        "transformer.h.4.attn.proj.weight": (8192, 8192),
        "transformer.h.4.mlp.fc_1.weight": (28672, 8192),
        "transformer.h.4.mlp.fc_2.weight": (28672, 8192),
        "transformer.h.4.mlp.proj.weight": (8192, 28672),
        "transformer.h.4.norm_1.weight": (8192,),
        "transformer.h.4.norm_2.weight": (8192,),
        "transformer.h.5.attn.proj.weight": (8192, 8192),
        "transformer.h.5.mlp.fc_1.weight": (28672, 8192),
        "transformer.wte.weight": (32000, 8192),
        "lm_head.weight": (32000, 8192),  # due to weight tying lm_head is in the converted weights
    }


@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2"))
def test_convert_hf_checkpoint(tmp_path, model_name):
    """Check ``convert_hf_checkpoint``: it rejects a directory with no ``.bin`` file, and with one (loaded through
    a mocked ``lazy_load``) it writes ``model_config.yaml`` and ``lit_model.pth`` next to it."""
    with pytest.raises(ValueError, match="to contain .bin"):
        convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)

    bin_file = tmp_path / "foo.bin"
    bin_file.touch()
    with mock.patch("litgpt.scripts.convert_hf_checkpoint.lazy_load") as load:
        # bypass if-statement for weight tying
        if model_name == "Llama-2-7b-hf":
            load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))}
        convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)
    load.assert_called_with(bin_file)

    assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "model_config.yaml", "lit_model.pth"}

    # ensure that the config dict can be loaded
    config = Config.from_file(tmp_path / "model_config.yaml")
    assert isinstance(config, Config)


def test_qkv_reassemble():
    """Check that ``qkv_reassemble`` turns per-group interleaved rows into contiguous Q, K, V row blocks.

    Each input row is tagged with its target position (0-3 queries, 4-7 keys, 8-11 values, in rows of 4), so the
    expected output is simply the rows sorted by value. Covered for MHA, GQA and MQA configurations.
    """
    # MHA
    config = Config(n_embd=4, n_head=4)
    qkv_interleaved = torch.tensor(
        [
            [0, 1, 2, 3],  # query
            [16, 17, 18, 19],  # key
            [32, 33, 34, 35],  # value
            [4, 5, 6, 7],  # query
            [20, 21, 22, 23],  # key
            [36, 37, 38, 39],  # value
            [8, 9, 10, 11],  # query
            [24, 25, 26, 27],  # key
            [40, 41, 42, 43],  # value
            [12, 13, 14, 15],  # query
            [28, 29, 30, 31],  # key
            [44, 45, 46, 47],  # value
        ]
    )
    qkv = qkv_reassemble(qkv_interleaved, config)
    torch.testing.assert_close(
        qkv,
        torch.tensor(
            [
                [0, 1, 2, 3],  # query
                [4, 5, 6, 7],  # query
                [8, 9, 10, 11],  # query
                [12, 13, 14, 15],  # query
                [16, 17, 18, 19],  # key
                [20, 21, 22, 23],  # key
                [24, 25, 26, 27],  # key
                [28, 29, 30, 31],  # key
                [32, 33, 34, 35],  # value
                [36, 37, 38, 39],  # value
                [40, 41, 42, 43],  # value
                [44, 45, 46, 47],  # value
            ]
        ),
    )

    # GQA
    config = Config(n_embd=4, n_head=4, n_query_groups=2)
    qkv_interleaved = torch.tensor(
        [
            [0, 1, 2, 3],  # query
            [4, 5, 6, 7],  # query
            [16, 17, 18, 19],  # key
            [24, 25, 26, 27],  # value
            [8, 9, 10, 11],  # query
            [12, 13, 14, 15],  # query
            [20, 21, 22, 23],  # key
            [28, 29, 30, 31],  # value
        ]
    )
    qkv = qkv_reassemble(qkv_interleaved, config)
    torch.testing.assert_close(
        qkv,
        torch.tensor(
            [
                [0, 1, 2, 3],  # query
                [4, 5, 6, 7],  # query
                [8, 9, 10, 11],  # query
                [12, 13, 14, 15],  # query
                [16, 17, 18, 19],  # key
                [20, 21, 22, 23],  # key
                [24, 25, 26, 27],  # value
                [28, 29, 30, 31],  # value
            ]
        ),
    )

    # MQA
    config = Config(n_embd=4, n_head=4, n_query_groups=1)
    qkv_interleaved = torch.tensor(
        [
            [0, 1, 2, 3],  # query
            [4, 5, 6, 7],  # query
            [8, 9, 10, 11],  # query
            [12, 13, 14, 15],  # query
            [16, 17, 18, 19],  # key
            [20, 21, 22, 23],  # value
        ]
    )
    qkv = qkv_reassemble(qkv_interleaved, config)
    torch.testing.assert_close(
        qkv,
        torch.tensor(
            [
                [0, 1, 2, 3],  # query
                [4, 5, 6, 7],  # query
                [8, 9, 10, 11],  # query
                [12, 13, 14, 15],  # query
                [16, 17, 18, 19],  # key
                [20, 21, 22, 23],  # value
            ]
        ),
    )
================================================ FILE: tests/convert/test_lit_checkpoint.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import asdict
from unittest.mock import ANY

import pytest
import torch
import yaml
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.falcon import FalconConfig, FalconForCausalLM
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.phi.configuration_phi import PhiConfig
from transformers.models.phi.modeling_phi import PhiForCausalLM
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from litgpt import GPT, Config
from litgpt.scripts.convert_lit_checkpoint import (
    check_conversion_supported,
    convert_lit_checkpoint,
    copy_weights_falcon,
    copy_weights_gemma_2,
    copy_weights_gemma_3,
    copy_weights_gpt_neox,
    copy_weights_llama,
    copy_weights_phi,
    copy_weights_qwen_2_5,
    qkv_reassemble,
)
from litgpt.utils import _RunIf


@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2"))
def test_convert_lit_checkpoint(tmp_path, model_name):
    """Runs ``convert_lit_checkpoint`` on a small random litgpt checkpoint.

    Checks three things:
    - ``model.pth`` is written into the separate output directory.
    - Nothing else appears next to the input files.
    - A checkpoint wrapped as ``{"model": state_dict}`` is unwrapped before export.
    """
    ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
    ours_model = GPT(ours_config)
    checkpoint_path = tmp_path / "lit_model.pth"
    config_path = tmp_path / "model_config.yaml"
    torch.save(ours_model.state_dict(), checkpoint_path)
    with open(config_path, "w", encoding="utf-8") as fp:
        yaml.dump(asdict(ours_config), fp)
    output_dir = tmp_path / "out_dir"

    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
    assert set(os.listdir(tmp_path)) == {"lit_model.pth", "model_config.yaml", "out_dir"}
    assert os.path.isfile(output_dir / "model.pth")

    # check checkpoint is unwrapped
    torch.save({"model": ours_model.state_dict()}, checkpoint_path)
    convert_lit_checkpoint(checkpoint_path.parent, output_dir)
    converted_sd = torch.load(output_dir / "model.pth")
    assert "model" not in converted_sd


@torch.inference_mode()
def test_against_falcon_40b():
    """Checks that weights exported by ``copy_weights_falcon`` reproduce our logits in HF
    ``FalconForCausalLM``.

    Uses a tiny falcon-40b config: new decoder architecture, 8 heads in 4 query groups.
    """
    ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
    theirs_config = FalconConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_hidden_layers=ours_config.n_layer,
        num_attention_heads=ours_config.n_head,
        num_kv_heads=ours_config.n_query_groups,
        new_decoder_architecture=True,
        parallel_attn=ours_config.parallel_residual,
        bias=ours_config.bias,
    )

    ours_model = GPT(ours_config)
    ours_state_dict = ours_model.state_dict()
    theirs_state_dict = {}
    copy_weights_falcon(ours_config, theirs_state_dict, ours_state_dict)
    theirs_model = FalconForCausalLM(theirs_config)
    # assign must be set to True for torch.testing.assert_close to pass
    theirs_model.load_state_dict(theirs_state_dict, assign=True)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"]
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
def test_against_original_gpt_neox():
    """Checks that weights exported by ``copy_weights_gpt_neox`` reproduce our logits in HF
    ``GPTNeoXForCausalLM``.

    The input is a random batch of two sequences of full ``block_size`` length.
    """
    ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
    # vocab_size=100 is padded up; the HF model below is configured with the padded size
    assert ours_config.padded_vocab_size == 512
    theirs_config = GPTNeoXConfig(
        hidden_act="gelu",
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        initializer_range=0.02,
        intermediate_size=ours_config.intermediate_size,
        layer_norm_eps=1e-05,
        max_position_embeddings=ours_config.block_size,
        rotary_emb_base=10000,
        rotary_pct=ours_config.rotary_percentage,
        vocab_size=ours_config.padded_vocab_size,
        use_parallel_residual=ours_config.parallel_residual,
    )

    ours_model = GPT(ours_config)
    ours_state_dict = ours_model.state_dict()
    theirs_state_dict = {}
    copy_weights_gpt_neox(ours_config, theirs_state_dict, ours_state_dict)
    theirs_model = GPTNeoXForCausalLM(theirs_config)
    # strict=False because we don't save the rotary embeddings inv frequency
    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
    assert not keys.unexpected_keys
    assert all("inv_freq" in k for k in keys.missing_keys)

    # test end to end
    x = torch.randint(0, ours_config.padded_vocab_size, size=(2, ours_config.block_size), dtype=torch.int64)
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"]
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
    "ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
)
def test_against_hf_llama2(ours_kwargs):
    """Checks that weights exported by ``copy_weights_llama`` reproduce our logits in HF
    ``LlamaForCausalLM``.

    Covers Llama 2, Code Llama and Llama 2 70B chat configs, each shrunk to 2 layers.
    """
    ours_config = Config.from_name(
        padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
    )
    T = 5
    theirs_config = LlamaConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        # HF's kwarg is `num_key_value_heads` (as in the Mixtral test). An unknown kwarg would be
        # stored as an unused attribute, and LlamaConfig would fall back to one KV head per head.
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    ours_model = GPT(ours_config)
    ours_state_dict = ours_model.state_dict()
    theirs_state_dict = {}
    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
    theirs_model = LlamaForCausalLM(theirs_config)
    theirs_model.load_state_dict(theirs_state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"]
    torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode() @pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1")) def test_against_mixtral(model_name): ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, n_expert=4, ) T = 5 theirs_config = MixtralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, num_local_experts=ours_config.n_expert, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict) theirs_model = MixtralForCausalLM(theirs_config) theirs_model.load_state_dict(theirs_state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32) ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) def test_against_olmo(model_name): ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, ) T = 5 theirs_config = OlmoConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, num_hidden_layers=ours_config.n_layer, num_attention_heads=ours_config.n_head, num_key_value_heads=ours_config.n_query_groups, max_positional_embeddings=T, attention_bias=ours_config.bias, rope_theta=ours_config.rope_base, tie_word_embeddings=(model_name == "OLMo-1B-hf"), ) assert 
ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config) # tie weights ours_model.lm_head.weight = ours_model.transformer.wte.weight ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=(model_name == "OLMo-1B-hf")) theirs_model = OlmoForCausalLM(theirs_config) keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() def test_against_original_open_llama_3b(): ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86) T = 5 theirs_config = LlamaConfig( hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict) theirs_model = LlamaForCausalLM(theirs_config) theirs_model.load_state_dict(theirs_state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2")) def test_against_hf_phi(model_name): ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5 ) T = 5 theirs_config = PhiConfig( vocab_size=ours_config.padded_vocab_size, 
max_position_embeddings=ours_config.block_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, partial_rotary_factor=ours_config.rotary_percentage, ) ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict) theirs_model = PhiForCausalLM(theirs_config) # strict=False because we don't save the rotary embeddings inv frequency keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys assert all("inv_freq" in k for k in keys.missing_keys) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct",)) def test_against_hf_phi_3(model_name): ours_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256) T = 5 theirs_config = Phi3Config( attention_bias=ours_config.bias, head_dim=ours_config.head_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, num_key_value_heads=ours_config.n_query_groups, pad_token_id=ours_config.padded_vocab_size - 1, partial_rotary_factor=ours_config.rotary_percentage, rms_norm_eps=ours_config.norm_eps, rope_theta=ours_config.rope_base, vocab_size=ours_config.padded_vocab_size, ) ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict) theirs_model = Phi3ForCausalLM(theirs_config) # strict=False because we don't save the rotary embeddings inv frequency keys = 
theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys assert all("inv_freq" in k for k in keys.missing_keys) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() def test_against_original_stablelm_zephyr_3b(): T = 5 ours_config = Config.from_name("stablelm-zephyr-3b", n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = AutoConfig.from_pretrained( "stabilityai/stablelm-zephyr-3b", trust_remote_code=True, num_hidden_layers=ours_config.n_layer, num_attention_heads=ours_config.n_head, num_key_value_heads=ours_config.n_head, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict) theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True, torch_dtype=torch.float32) theirs_model.load_state_dict(theirs_state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma(model_name, device, dtype): torch.set_default_dtype(dtype) T = 5 ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config).to(device) # tie weights ours_model.lm_head.weight = ours_model.transformer.wte.weight ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_model.load_state_dict( theirs_state_dict, strict=False, ) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-2-2b", "gemma-2-9b", "gemma-2-27b")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_2(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config).to(device) # tie weights ours_model.lm_head.weight = ours_model.transformer.wte.weight ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_gemma_2(ours_config, theirs_state_dict, ours_state_dict) theirs_model = Gemma2ForCausalLM(theirs_config).to(device) keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float 
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it")) @pytest.mark.parametrize( ("device", "dtype"), [ pytest.param(torch.device("cpu"), torch.float32, marks=[pytest.mark.flaky(reruns=3)]), pytest.param( torch.device("cuda"), torch.float16, marks=[ # todo: the reference does softmax upscaled to fp32 during attention # additionally, the final layernorm input is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_3(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config).to(device) # tie weights ours_model.lm_head.weight = ours_model.transformer.wte.weight ours_state_dict = ours_model.state_dict() theirs_state_dict = {} 
copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict) theirs_model = Gemma3ForCausalLM(theirs_config).to(device) keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) def test_check_conversion_supported_adapter(): lit_weights = {"some.key.name": ANY, "error.key.gating_factor": ANY} with pytest.raises(NotImplementedError, match="Converting adapter"): check_conversion_supported(lit_weights=lit_weights) lit_weights = {"some.key.name": ANY, "error.key.adapter_bias": ANY} with pytest.raises(NotImplementedError, match="Converting adapter"): check_conversion_supported(lit_weights=lit_weights) def test_check_conversion_supported_lora(): lit_weights = {"some.key.name": ANY, "error.key.lora": ANY} with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"): check_conversion_supported(lit_weights=lit_weights) @torch.inference_mode() @pytest.mark.parametrize( "model_name", ( "Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B", "Qwen2.5-7B-Instruct-1M", ), ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_qwen_2_5(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Qwen2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.attn_bias, tie_word_embeddings=True, ) assert ours_config.intermediate_size == theirs_config.intermediate_size ours_model = GPT(ours_config).to(device) # tie weights ours_model.lm_head.weight = ours_model.transformer.wte.weight ours_state_dict = ours_model.state_dict() theirs_state_dict = {} copy_weights_qwen_2_5(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) theirs_model = Qwen2ForCausalLM(theirs_config).to(device) keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) assert not keys.unexpected_keys # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) def test_qkv_reassemble(): # MHA config = Config(n_embd=4, n_head=4) qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query [8, 9, 10, 11], # query [12, 13, 14, 15], # query [16, 17, 18, 19], # key [20, 21, 22, 23], # key [24, 25, 26, 27], # key [28, 29, 30, 31], # key [32, 33, 34, 35], # value [36, 37, 38, 39], # value [40, 41, 
42, 43], # value [44, 45, 46, 47], # value ] ) qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query [16, 17, 18, 19], # key [32, 33, 34, 35], # value [4, 5, 6, 7], # query [20, 21, 22, 23], # key [36, 37, 38, 39], # value [8, 9, 10, 11], # query [24, 25, 26, 27], # key [40, 41, 42, 43], # value [12, 13, 14, 15], # query [28, 29, 30, 31], # key [44, 45, 46, 47], # value ] ), ) # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query [8, 9, 10, 11], # query [12, 13, 14, 15], # query [16, 17, 18, 19], # key [20, 21, 22, 23], # key [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ) qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query [16, 17, 18, 19], # key [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query [20, 21, 22, 23], # key [28, 29, 30, 31], # value ] ), ) # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query [8, 9, 10, 11], # query [12, 13, 14, 15], # query [16, 17, 18, 19], # key [20, 21, 22, 23], # value ] ) qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query [8, 9, 10, 11], # query [12, 13, 14, 15], # query [16, 17, 18, 19], # key [20, 21, 22, 23], # value ] ), ) ================================================ FILE: tests/convert/test_pretrained_checkpoint.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os

import torch

from litgpt.scripts.convert_pretrained_checkpoint import convert_pretrained_checkpoint


def test_convert_pretrained_checkpoint(tmp_path, fake_checkpoint_dir):
    """Checks that ``convert_pretrained_checkpoint`` turns a pretraining checkpoint into a plain
    model checkpoint.

    Only the ``"model"`` entry is kept (optimizer and other entries are dropped), and the
    ``_orig_mod.`` key prefix is stripped. The output directory must contain exactly the
    weights, ``model_config.yaml`` and the two tokenizer files. Presumably those config and
    tokenizer files come from the ``fake_checkpoint_dir`` fixture — confirm in conftest.
    """
    # Pretend we made a checkpoint from pretraining
    pretrained_checkpoint = {
        # "_orig_mod." is presumably the prefix torch.compile adds to wrapped module keys
        "model": {"some.module.weight": torch.rand(2, 2), "_orig_mod.some.other.module.weight": torch.rand(2, 2)},
        "the_optimizer": "optimizer_state",
        "other": 1,
    }
    torch.save(pretrained_checkpoint, fake_checkpoint_dir / "lit_model.pth")

    convert_pretrained_checkpoint(checkpoint_dir=fake_checkpoint_dir, output_dir=(tmp_path / "converted"))

    assert set(os.listdir(tmp_path / "converted")) == {
        "lit_model.pth",
        "model_config.yaml",
        "tokenizer_config.json",
        "tokenizer.json",
    }
    converted_checkpoint = torch.load(tmp_path / "converted" / "lit_model.pth")
    assert list(converted_checkpoint.keys()) == ["some.module.weight", "some.other.module.weight"]

================================================
FILE: tests/data/__init__.py
================================================

================================================
FILE: tests/data/_fixtures/alpaca.json
================================================
[
    {
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
    },
    {
        "instruction": "What are the three primary colors?",
        "input": "",
        "output": "The three primary colors are red, blue, and yellow."
    },
    {
        "instruction": "Describe the structure of an atom.",
        "input": "",
        "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom.
The number of each particle determines the atomic number and the type of atom." }, { "instruction": "How can we reduce air pollution?", "input": "", "output": "There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances." }, { "instruction": "Pretend you are a project manager of a construction company. Describe a time when you had to make a difficult decision.", "input": "", "output": "I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities." }, { "instruction": "Identify the odd one out.", "input": "Twitter, Instagram, Telegram", "output": "Telegram" }, { "instruction": "Explain why the following fraction is equivalent to 1/4", "input": "4/16", "output": "The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4." 
}, { "instruction": "Write a short story in third person narration about a protagonist who has to make an important career decision.", "input": "", "output": "John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him." }, { "instruction": "Evaluate this sentence for spelling and grammar mistakes", "input": "He finished his meal and left the restaurant", "output": "He finished his meal and left the restaurant." }, { "instruction": "How did Julius Caesar die?", "input": "", "output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC." }, { "instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris." }, { "instruction": "Generate a list of ten items a person might need for a camping trip", "input": "", "output": "1. Tent\n2. Sleeping bags\n3. Flashlight\n4. 
Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos" }, { "instruction": "Discuss the causes of the Great Depression", "input": "", "output": "The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade." }, { "instruction": "Classify the following into animals, plants, and minerals", "input": "Oak tree, copper or, elephant", "output": "Oak tree: Plant\n Copper or: Mineral\n Elephant: Animal" }, { "instruction": "Explain the use of word embeddings in Natural Language Processing", "input": "", "output": "Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks." }, { "instruction": "Describe the function of a computer motherboard", "input": "", "output": "A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. 
The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types." }, { "instruction": "Reverse engineer this code to create a new version", "input": "def factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial", "output": "def factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial" }, { "instruction": "Propose an ethical solution to the problem of data privacy", "input": "", "output": "One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data." }, { "instruction": "Generate three verbs that mean the same as \"to apologize\"", "input": "", "output": "Acknowledge, express regret, make amends." }, { "instruction": "What does DNA stand for?", "input": "", "output": "DNA stands for deoxyribonucleic acid." }, { "instruction": "Compare and contrast the Cuban Missile Crisis and the Vietnam War.", "input": "", "output": "The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. 
The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force." }, { "instruction": "Generate a list of random words.", "input": "", "output": "sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic." }, { "instruction": "Transcribe the recording into text.", "input": "Recording about the history of Waterloo", "output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. 
The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years." } ] ================================================ FILE: tests/data/_fixtures/dolly.json ================================================ [ { "instruction": "When did Virgin Australia start operating?", "context": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.", "response": "Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.", "category": "closed_qa" }, { "instruction": "Which is a species of fish? 
Tope or Rope", "context": "", "response": "Tope", "category": "classification" }, { "instruction": "Why can camels survive for long without water?", "context": "", "response": "Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.", "category": "open_qa" }, { "instruction": "Alice's parents have three daughters: Amy, Jessy, and what\u2019s the name of the third daughter?", "context": "", "response": "The name of the third daughter is Alice", "category": "open_qa" }, { "instruction": "When was Tomoaki Komorida born?", "context": "Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. 
He played often as a defensive midfielder and center back until 2012 when he retired.", "response": "Tomoaki Komorida was born on July 10,1981.", "category": "closed_qa" }, { "instruction": "If I have more pieces at the time of stalemate, have I won?", "context": "Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess family.", "response": "No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position", "category": "information_extraction" }, { "instruction": "Given a reference text about Lollapalooza, where does it take place, who started it and what is it?", "context": "Lollapalooza /\u02ccl\u0252l\u0259p\u0259\u02c8lu\u02d0z\u0259/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. 
The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.\n\nLollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addiction.", "response": "Lollapalooze is an annual musical festival held in Grant Park in Chicago, Illinois. It was started in 1991 as a farewell tour by Perry Farrell, single of the group Jane's Addiction. The festival includes an array of musical genres including alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. The festivals welcomes an estimated 400,000 people each year and sells out annually. Some notable headliners include: the Red Hot Chili Peppers, Chance the Rapper, Metallica, and Lady Gauge. Lollapalooza is one of the largest and most iconic festivals in the world and a staple of Chicago.", "category": "closed_qa" }, { "instruction": "Who gave the UN the land in NY to build their HQ", "context": "", "response": "John D Rockerfeller", "category": "open_qa" }, { "instruction": "Why mobile is bad for human", "context": "", "response": "We are always engaged one phone which is not good.", "category": "brainstorming" }, { "instruction": "Who was John Moses Browning?", "context": "John Moses Browning (January 23, 1855 \u2013 November 26, 1926) was an American firearm designer who developed many varieties of military and civilian firearms, cartridges, and gun mechanisms \u2013 many of which are still in use around the world. He made his first firearm at age 13 in his father's gun shop and was awarded the first of his 128 firearm patents on October 7, 1879, at the age of 24. 
He is regarded as one of the most successful firearms designers of the 19th and 20th centuries and pioneered the development of modern repeating, semi-automatic, and automatic firearms.\n\nBrowning influenced nearly all categories of firearms design, especially the autoloading of ammunition. He invented, or made significant improvements to, single-shot, lever-action, and pump-action rifles and shotguns. He developed the first reliable and compact autoloading pistols by inventing the telescoping bolt, then integrating the bolt and barrel shroud into what is known as the pistol slide. Browning's telescoping bolt design is now found on nearly every modern semi-automatic pistol, as well as several modern fully automatic weapons. He also developed the first gas-operated firearm, the Colt\u2013Browning Model 1895 machine gun \u2013 a system that surpassed mechanical recoil operation to become the standard for most high-power self-loading firearm designs worldwide. He also made significant contributions to automatic cannon development.\n\nBrowning's most successful designs include the M1911 pistol, the water-cooled M1917, the air-cooled M1919, and heavy M2 machine guns, the M1918 Browning Automatic Rifle, and the Browning Auto-5 \u2013 the first semi-automatic shotgun. Some of these arms are still manufactured, often with only minor changes in detail and cosmetics to those assembled by Browning or his licensees. The Browning-designed M1911 and Hi-Power are some of the most copied firearms in the world.", "response": "John Moses Browning is one of the most well-known designer of modern firearms. 
He started building firearms in his father's shop at the age of 13, and was awarded his first patent when he was 24.\n\nHe designed the first reliable automatic pistol, and the first gas-operated firearm, as well inventing or improving single-shot, lever-action, and pump-action rifles and shotguns.\n\nToday, he is most well-known for the M1911 pistol, the Browning Automatic Rifle, and the Auto-5 shotgun, all of which are in still in current production in either their original design, or with minor changes. His M1911 and Hi-Power pistols designs are some of the most reproduced firearms in the world today.", "category": "information_extraction" }, { "instruction": "Who is Thomas Jefferson?", "context": "Thomas Jefferson (April 13, 1743 \u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.", "response": "Thomas Jefferson (April 13, 1743 \u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. 
Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.\n\nAmong the nation's Founding Fathers, Jefferson is considered unmatched in his intellectual depth and breadth. His passionate writings and advocacy for human rights, including freedom of thought, speech, and religion, were a leading inspiration behind the American Revolution, which ultimately gave rise to the American Revolutionary War, American independence, and the United States Constitution. Jefferson's ideas were globally influential in shaping and inspiring the Age of Enlightenment, which proved transformational in the late 17th and 18th centuries. He was a leading proponent of democracy, republicanism, and individual rights, and produced formative documents and decisions at the state, national, and international levels.\n\nDuring the American Revolution, Jefferson represented Virginia in the Second Continental Congress in Philadelphia, which adopted the Declaration of Independence on July 4, 1776. As a Virginia legislator, he drafted a state law for religious freedom. He served as the second Governor of Virginia from 1779 to 1781, during the Revolutionary War. In 1785, Jefferson was appointed the United States Minister to France, and subsequently, the nation's first secretary of state under President George Washington from 1790 to 1793. Jefferson and James Madison organized the Democratic-Republican Party to oppose the Federalist Party during the formation of the First Party System. 
With Madison, he anonymously wrote the Kentucky and Virginia Resolutions in 1798 and 1799, which sought to strengthen states' rights by nullifying the federal Alien and Sedition Acts.\n\nJefferson and Federalist John Adams became friends as well as political rivals, serving in the Continental Congress and drafting the Declaration of Independence together. In the 1796 presidential election between the two, Jefferson came in second, which according to electoral procedure at the time, made him vice president to Adams. Jefferson challenged Adams again in 1800 and won the presidency. After his term in office, Jefferson eventually reconciled with Adams and they shared a correspondence that lasted 14 years. He and Adams both died on the same day, July 4, 1826, which was also the 50th anniversary of Declaration of Independence.\n\nAs president, Jefferson pursued the nation's shipping and trade interests against Barbary pirates and aggressive British trade policies. Starting in 1803, he promoted a western expansionist policy with the Louisiana Purchase, which doubled the nation's claimed land area. To make room for settlement, Jefferson began the process of Indian tribal removal from the newly acquired territory. As a result of peace negotiations with France, his administration reduced military forces. He was re-elected in 1804, but his second term was beset with difficulties at home, including the trial of former vice president Aaron Burr. In 1807, American foreign trade was diminished when Jefferson implemented the Embargo Act in response to British threats to U.S. shipping. The same year, Jefferson signed the Act Prohibiting Importation of Slaves.\n\nJefferson was a plantation owner, lawyer, and politician, and mastered many disciplines including surveying, mathematics, horticulture, and mechanics. He was also an architect in the Palladian tradition. 
Jefferson's keen interest in religion and philosophy led to his appointment as president of the American Philosophical Society. He largely shunned organized religion but was influenced by Christianity, Epicureanism, and deism. Jefferson rejected fundamental Christianity, denying Christ's divinity. A philologist, Jefferson knew several languages. He was a prolific letter writer and corresponded with many prominent people, including Edward Carrington, John Taylor of Caroline, and James Madison. In 1785, Jefferson authored Notes on the State of Virginia, considered perhaps the most important American book published before 1800. Jefferson championed the ideals, values, and teachings of the Enlightenment.\n\nSince the 1790s, Jefferson was rumored to have had children by his sister-in-law and slave Sally Hemings, leading to what is known as the Jefferson-Hemings controversy. A 1998 DNA test concluded that one of Sally Hemings's children, Eston Hemings, was of the Jefferson male line. According to scholarly consensus, based on documentary and statistical evaluation, as well as oral history, Jefferson probably fathered at least six children with Hemings, including four that survived to adulthood.\n\nAfter retiring from public office, Jefferson founded the University of Virginia. Presidential scholars and historians generally praise Jefferson's public achievements, including his advocacy of religious freedom and tolerance in Virginia, his peaceful acquisition of the Louisiana Territory from France without war or controversy, and his ambitious and successful Lewis and Clark Expedition. Some modern historians are critical of Jefferson's personal involvement with slavery. 
Jefferson is consistently ranked among the top ten presidents of American history.", "category": "information_extraction" }, { "instruction": "Who was Kyle Van Zyl playing against when he scored 36 of hisa teams 61 points?", "context": "Van Zyl joined the Eastern Province Kings Academy, where he played for the Eastern Province U19 side in the 2010 Under-19 Provincial Championship. He was a key player for the Eastern Province U21 side in the 2012 Under-21 Provincial Championship, scoring 71 points in eight appearances. Van Zyl was under the Top SARU Performers, scoring the most tries at 6 in the 2012 Provincial Under 21 in the Rugby Junior Provincials.\n\nThis included a record and a remarkable personal haul in their opening match, when he scored 36 of his team's points in a 61\u20133 victory over Boland U21, consisting of four tries and eight conversions and was awarded Man of the Match.", "response": "Kyle Van Zyl was playing against Boland U21 when he scored 36 points, leading his team to victory in a 61-3 win.", "category": "closed_qa" } ] ================================================ FILE: tests/data/_fixtures/longform_train.json ================================================ [ { "input": "What are the positions held by Beto O'Rourke, Lupe Valdez, and Veronica Escobar on decriminalizing unauthorized border crossings in the US?", "output": "The liberal mainstream media squealed with delight last week when they announced Ted Cruz was only 4 points ahead of Democrat Beto O\u2019Rourke.\nThe the bottom dropped out on Beto\u2019s campaign.\nFirst O\u2019Rourke told a liberal Texas crowd that there was nothing more American than kneeling for the National Anthem.\nBeto O\u2019Rourke is now pushing to legalize unauthorized border crossers into America.\nA trio of Democrats running for top offices in Texas want to decriminalize unauthorized border crossings.\nU.S. Rep. Beto O\u2019Rourke, who\u2019s challenging Republican Sen. 
Ted Cruz for his Senate seat, gubernatorial candidate Lupe Valdez, and congressional candidate Veronica Escobar told HuffPost they are in favor of such a move, which the report says would be simpler to implement than abolishing Immigration and Customs Enforcement, as some Democrats have demanded.\nDecriminalizing unauthorized border crossings would go against the Trump administration\u2019s \u201czero tolerance\u201d policy that led to family separations at the border. The Department of Homeland Security would still be able to charge unauthorized immigrants with civil violations, the report notes, but it would save the federal government billions in incarceration costs.", "source": "C4", "subset": "search-engine" }, { "input": "For example, for the original quote we shall use, 'One swallow does not a summer make'. If someone were to say 'One good quote does not a philosopher make' (No dig at Aristotle intended...), or 'A fancy suit does not a businessman make'. Essentially, a clear allusion to a quote, but not the quote itself.", "output": "The word I've heard used for this on the linguistics blog Language Log is snowclone (it's derived from phrases of the format 'If Eskimos have N words for snow, X surely have Y words for Z.') Here's a Language Log post that describes the moment the word was coined and that links to some examples: Snowclones: Lexicographical Dating to the Second And here's how it's defined by Collins Dictionaries online: a verbal formula that is adapted for reuse by changing only a few words so that the allusion to the original phrase remains clear You might also be interested in looking at the answers to the following questions: This book will change your wife trope? 
Is there a general name for this: 'The more you X the more you'll Y' A term describing the replacement of a specific word in a saying with one that ryhmes", "source": "Natural Instructions", "subset": "task225 english language answer generation" }, { "input": "Describe the U.S.-Mexico border. Respond in detail.", "output": "President Donald Trump\u2019s border wall with Mexico will have an initial price tag of between $12 billion and $15 billion, according to Congressional Republican leaders, but the method of paying for Trump\u2019s promised project is still not settled.\nAt the Congressional Republican Retreat in Philadelphia on Thursday, House Speaker Paul Ryan said the wall would cost up to $15 billion, more than the $8 billion Trump has estimated the wall will cost.\nA U.S. Border Patrol agent removes a ladder used by undocumented immigrants to climb a border fence on Oct. 18, 2016 near McAllen, Texas.\nU.S. Customs and Border Protection agents fly near the U.S.-Mexico border while on helicopter patrol on Oct. 18, 2016 near McAllen, Texas.\nA U.S. Border Patrol agent detains a group of undocumented immigrants on Oct. 18, 2016 near McAllen, Texas.\nPeople stand in line to cross legally into the United States from Mexico on Sept. 24, 2016 in Tijuana, Mexico.\nThe Rio Grande forms a stretch of the U.S.-Mexico border on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas. Big Bend is a rugged, vast and remote region along the U.S.-Mexico border and includes the Big Bend National Park.\nU.S. Border Patrol agents with a K-9 unit detain undocumented immigrants after they illegally crossed the U.S.-Mexico border on Oct. 18, 2016, in McAllen, Texas.\nThis photo made with a smart phone through night vision goggles shows the Rio Grande flowing along the U.S.-Mexico border, as seen from a U.S. Customs and Border Protection helicopter during a patrol over the U.S.-Mexico border on Oct. 
18, 2016 in McAllen, Texas.\nThe moon rises over the swirling current of the Rio Grande on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas.\nA bullet-proof shield stands to aid U.S. Border Patrol agents on the U.S.-Mexico border on Oct. 3, 2016 in El Paso, Texas.\nA child plays in the Pacific surf near the U.S.-Mexico border fence on Sept. 25, 2016 in Tijuana, Mexico. The nearby Friendship Park is one of the few places on the 2,000-mile border where separated families are allowed to meet.\nDunes stretch into the distance near the U.S.-Mexico border on Sept. 27, 2016 in the Imperial Sand Dunes recreation area, California.\nMexican farm workers hoe a cabbage field on Sept. 27, 2016, in Holtville, Calif. Thousands of Mexican seasonal workers legally cross over daily from Mexicali, Mexico to work the fields of Imperial Valley, Calif., which is some of the most productive farmland in the United States.\nA man looks through the U.S.-Mexico border fence into the United States on Septt. 25, 2016 in Tijuana, Mexico.\nA Studebaker sits outside a cotton field near the U.S.-Mexico border on Oct. 14, 2016 near Fort Hancock, Texas.\nBodies of suspected undocumented immigrants lie in the morgue at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz. Hundreds of migrants die every year in the desert while crossing illegally from Mexico into the United States.\nPersonal effects found on the body of a suspected undocumented immigrant are stored at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz.\nDinner tables await immigrants, many of them deportees from the United States, at the Casa del Migrante Sept. 23, 2016 in Tijuana, Mexico. The shelter, run by Catholic priests, is part of the Coalicion Pro Defensa, which aids immigrants and asylum applicants seeking entry into the United States. 
The coalition of immigrant advocates has helped thousands of migrants on their journey north and many more after being being deported from the United States.\nMexicans enjoy a late afternoon near the U.S.-Mexico border fence, which ends in the Pacific Ocean, on Sept. 25, 2016 in Tijuana, Mexico. Friendship Park, located on the border between the two countries is one of the few places on the 2,000-mile border where separated families are allowed to meet.\nTrump signed two executive orders on Wednesday directing the construction of the border wall and increasing the number of immigration enforcement officers carrying out deportation, as well as border patrol forces.", "source": "C4", "subset": "instruction" }, { "input": "Must a Boss use their Level Up ability?\n\nMost Level Up abilities clearly benefit the Boss, but Robobo's is:\n\nEach opponent mush choose and destroy one Room in his or her Dungeon.\n\nThe opponents get to destroy a room (often quite beneficial), and Robobo doesn't!\nIs there any way - short of only building a four room dungeon - that Robobo can avoid using this ability? What play strategies can Robobo use to maximize his Level Up ability?", "output": "According to the rules, your boss's \"Level Up\" ability is triggered when your dungeon is 5 rooms in size. \"Triggering\" implies necessity. So yes, the only way to avoid levelling up is by not building a fifth room.\nSecondly, and more to your concern... according to the Advanced Rules (see the entry for Robobo in the glossary using the link above), using Robobo's level up does not trigger the \"destroy this Room\" ability for the destroyed room.", "source": "StackExchange", "subset": "boardgames" }, { "input": "What is Podd\u0119bice?", "output": "Podd\u0119bice () is a village in the administrative district of Gmina W\u0142oc\u0142awek, within W\u0142oc\u0142awek County, Kuyavian-Pomeranian Voivodeship, in north-central Poland. 
It lies approximately south of W\u0142oc\u0142awek and south-east of Toru\u0144.", "source": "Wikipedia", "subset": "chatbot" }, { "input": "Describe the history of MV \"Alexia\" Respond in 4 sentences.", "output": "MV \"Alexia\" was one of nine Royal Dutch/Shell oil tankers converted to become a Merchant Aircraft Carrier (MAC ship). The group is collectively known as the \"Rapana\" class.\nMV \"Alexia\" was built at Bremer Vulkan and completed in April, 1935 as an oil tanker for the Anglo Saxon Royal Dutch/Shell line. She sustained severe damage in two separate U-boat attacks in 1940 and 1942.", "source": "Wikipedia", "subset": "instruction" }, { "input": "What is the definition of absurdity?", "output": "1. the state or quality of being absurd.\nAttending the Gerald R. Ford Symposium on Humor and the Presidency is sort of like attending the Ayatollah Khomeini Symposium on the sexual revolution \u2014Pat Paulsen, at September 19, 1986 symposium in Grand Rapids, Michigan.\n2. absurdity - a ludicrous folly; \"the crowd laughed at the absurdity of the clown's behavior\"\nnoun ridiculousness, nonsense, folly, stupidity, foolishness, silliness, idiocy, irrationality, incongruity, meaninglessness, daftness (informal), senselessness, illogicality, ludicrousness, unreasonableness, preposterousness, farcicality, craziness (informal), b\u00eatise (rare), farcicalness, illogicalness I get angry at the absurdity of a situation.\nfolly, foolery, foolishness, idiocy, imbecility, insanity, lunacy, madness, nonsense, preposterousness, senselessness, silliness, tomfoolery, zaniness.\nBut in the Epic poem the absurdity passes unnoticed.\nFor if absurdity be the subject of laughter, doubt you not but great boldness is seldom without some absurdity.\nI am temperate to the verge of absurdity,\" replied the Tramp.\nPoets, of course, may be satisfactorily read in volumes of, selections; but to me, at least, a book of brief extracts from twenty or a hundred prose authors is an 
absurdity.\nThen,\" suggested the idea, with a blush for its own absurdity, \"why not go on pilgrimage and seek her?\nJudges and starters have been conveniently blind to this absurdity, but the public demonstration off St.\nAnd now that the providential occurrence was apparently close at hand, it would have been sheer absurdity to think that the supply would be short of the need: as absurd as a faith that believed in half a miracle for want of strength to believe in a whole one.\nutterances, the absurdity being attested by his motley costume.\nOr what greater absurdity can there be than putting before us an old man as a swashbuckler, a young man as a poltroon, a lackey using fine language, a page giving sage advice, a king plying as a porter, a princess who is a kitchen-maid?\nNow to say that the honour I here mean, and which was, I thought, all the honour I could be supposed to mean, will uphold, much less dictate an untruth, is to assert an absurdity too shocking to be conceived.\nHence the absurdity of the interview; the gulf between them was economic as well as spiritual.", "source": "C4", "subset": "chatbot" }, { "input": "Can felons run for federal office in Minnesota?", "output": "Minnesota law doesn't block felons from running for federal office.\nEven if enough voters choose Leonard J. Richards as the DFL candidate for U.S. Senate, there\u2019s no way he will ever get to take the oath of office and begin a six-year term.\nRichards is already serving a lifetime term. In Stillwater prison. For murder. Make that two murders.\nYes, it is legal in Minnesota for felons to run for office, so long as it is a federal seat. Nobody knows that better than Richards, who is trying to wrest the party nod away from incumbent Amy Klobuchar.\nNow 75 years old and sporting Department of Corrections ID No. 149837, Richards has run for federal office several times \u2014 without a victory \u2014 since his imprisonment, most recently when he sought the seat that U.S. Rep. 
Tom Emmer now holds.\nIn 1992, Richards ran in the DFL primary for the Eighth Congressional District seat and received more than 14,500 votes. He ran for the U.S. Senate in the DFL primary in 1994, winning more than 4,000 votes.\nMinnesota law does not permit inmates to run for a state-level office.\nThe official ballot for Minnesota\u2019s primary next month lists double murderer Leonard Richards among those seeking the DFL nomination for U.S. Senate.\nRichards was convicted of murder twice in Hennepin County for the 1982 slaying of his half-sister, May Wilson, and the 1987 shooting death of his attorney, Robert Stratton. His life sentence imposed nearly 30 years ago offers no parole.\nRichards did not respond to a message Thursday seeking an interview.", "source": "C4", "subset": "chatbot" }, { "input": "What is the purpose of consciousness?", "output": "Cardiff University and University College London provide funding as founding partners of The Conversation UK.\nMost experts think that consciousness can be divided into two parts: the experience of consciousness (or personal awareness), and the contents of consciousness, which include things such as thoughts, beliefs, sensations, perceptions, intentions, memories and emotions.\nIt\u2019s easy to assume that these contents of consciousness are somehow chosen, caused or controlled by our personal awareness \u2013 after all, thoughts don\u2019t exist until until we think them. But in a new research paper in Frontiers of Psychology, we argue that this is a mistake.\nWe suggest that our personal awareness does not create, cause or choose our beliefs, feelings or perceptions. Instead, the contents of consciousness are generated \u201cbehind the scenes\u201d by fast, efficient, non-conscious systems in our brains. 
All this happens without any interference from our personal awareness, which sits passively in the passenger seat while these processes occur.\nPut simply, we don\u2019t consciously choose our thoughts or our feelings \u2013 we become aware of them.\nIf this sounds strange, consider how effortlessly we regain consciousness each morning after losing it the night before; how thoughts and emotions \u2013 welcome or otherwise \u2013 arrive already formed in our minds; how the colours and shapes we see are constructed into meaningful objects or memorable faces without any effort or input from our conscious mind.\nConsider that all the neuropsychological processes responsible for moving your body or using words to form sentences take place without involving your personal awareness. We believe that the processes responsible for generating the contents of consciousness do the same.\nOur thinking has been influenced by research into neuropsychological and neuropsychiatric disorders, as well as more recent cognitive neuroscience studies using hypnosis. The studies using hypnosis show that a person\u2019s mood, thoughts and perceptions can be profoundly altered by suggestion.\nIn such studies, participants go through a hypnosis induction procedure, to help them to enter a mentally focused and absorbed state. Then, suggestions are made to change their perceptions and experiences.\nFor example, in one study, researchers recorded the brain activity of participants when they raised their arm intentionally, when it was lifted by a pulley, and when it moved in response to a hypnotic suggestion that it was being lifted by a pulley.\nSimilar areas of the brain were active during the involuntary and the suggested \u201calien\u201d movement, while brain activity for the intentional action was different. 
So, hypnotic suggestion can be seen as a means of communicating an idea or belief that, when accepted, has the power to alter a person\u2019s perceptions or behaviour.\nAll this may leave one wondering where our thoughts, emotions and perceptions actually come from. We argue that the contents of consciousness are a subset of the experiences, emotions, thoughts and beliefs that are generated by non-conscious processes within our brains.\nThis subset takes the form of a personal narrative, which is constantly being updated. The personal narrative exists in parallel with our personal awareness, but the latter has no influence over the former.\nThe personal narrative is important because it provides information to be stored in your autobiographical memory (the story you tell yourself, about yourself), and gives human beings a way of communicating the things we have perceived and experienced to others.\nThis, in turn, allows us to generate survival strategies; for example, by learning to predict other people\u2019s behaviour. Interpersonal skills like this underpin the development of social and cultural structures, which have promoted the survival of human kind for millennia.\nSo, we argue that it is the ability to communicate the contents of one\u2019s personal narrative \u2013\u2013 and not personal awareness \u2013 that gives humans their unique evolutionary advantage.\nIf the experience of consciousness does not confer any particular advantage, it\u2019s not clear what its purpose is. But as a passive accompaniment to non-conscious processes, we don\u2019t think that the phenomenon of personal awareness has a purpose, in much the same way that rainbows do not. Rainbows simply result from the reflection, refraction and dispersion of sunlight through water droplets \u2013 none of which serves any particular purpose.\nOur conclusions also raise questions about the notions of free will and personal responsibility. 
If our personal awareness does not control the contents of the personal narrative which reflects our thoughts, feelings, emotions, actions and decisions, then perhaps we should not be held responsible for them.\nIn response to this, we argue that free will and personal responsibility are notions that have been constructed by society. As such, they are built into the way we see and understand ourselves as individuals, and as a species. Because of this, they are represented within the non-conscious processes that create our personal narratives, and in the way we communicate those narratives to others.\nJust because consciousness has been placed in the passenger seat, does not mean we need to dispense with important everyday notions such as free will and personal responsibility. In fact, they are embedded in the workings of our non-conscious brain systems. They have a powerful purpose in society and have a deep impact on the way we understand ourselves.", "source": "C4", "subset": "chatbot" }, { "input": "What is Nemaha, IA? Respond briefly.", "output": "Nemaha is a city in Sac County, Iowa, United States. The population was 85 at the 2010 census.\nAccording to the United States Census Bureau, the city has a total area of , all of it land.", "source": "Wikipedia", "subset": "chatbot" }, { "input": "In the book I am reading (Patrick Fairbairn, The Interpretation of Prophecy, 2nd ed., 1865), the abbreviation 'sq.' is used sometimes in referencing both Scripture and other books (as 'Typology of Scripture, vol. i. p. 100 sq.' or 'Ezek xxvi. 7, sq.'). What does 'sq.' mean in this context? (Instances of this abbreviation's usage may be found at Google Books here.).", "output": "Sq. is short for the Latin word sequiturque, which means the following. So p. 100 sq. means page 100 and those following pages. 
Here's the source.", "source": "Natural Instructions", "subset": "task225 english language answer generation" }, { "input": "\"Describe Brugnens.\" Respond in 11 words.", "output": " Brugnens is a commune in the Gers department in southwestern France.", "source": "Wikipedia", "subset": "instruction" }, { "input": "What is the DeLand Municipal Airport? Respond in 1 sentences.", "output": "DeLand Municipal Airport , also known as Sidney H. Taylor Field, is a city-owned public-use airport located three nautical miles (5\u00a0km) northeast of the central business district of DeLand, a city in Volusia County, Florida, United States.", "source": "Wikipedia", "subset": "chatbot" }, { "input": "Respond in 203 words. Describe the plot of the 1976 Filipino drama film \"Minsa'y Isang Gamu-gamo\".", "output": "Minsa'y Isang Gamu-gamo (\"\") is a 1976 Filipino drama film directed by Lupita Aquino-Kashiwahara and written by Marina Feleo-Gonzales. It concerns a Filipino nurse, Cora de la Cruz, who dreams of moving to America. When her brother is killed, her ideas change. The film criticises American military presence in the Philippines. It was an entry to the 1976 Metro Manila Film Festival and won five awards at the 25th FAMAS Awards, including those for best picture, director and screenplay. The film was restored by the ABS-CBN Film Restoration through Kantana Post Production in Thailand and the restored version was released on October 12, 2018 as part of Cinema One Originals film festival.\nThe de la Cruises and the Santos are two lower-middle-class families who live in Pampanga. Cora de la Cruz is a nurse who dreams of living in the United States. Her papers are ready and she organises employment at an American hospital. She hopes to get a green card allowing her to stay, achieve immigrant status and then bring her family to America for a better life. Cora's mother and younger brother encourage her but her father, Ingkong, disagrees. 
He believes that moving to America is a betrayal of their country.", "source": "Wikipedia", "subset": "instruction" }, { "input": "Hey, when does Marvel's Runaways Season 2 start?", "output": "The second season picks up where the first left off.\nHulu has announced that the second season of Marvel's Runaways will premiere December 21st, Deadline reports. And there won't be any weekly installments this time around as there were with the first season because all 13 episodes are getting dropped at once. Hulu renewed the series for a second season just as the first one wrapped up.\nSeason 2 will kick off right where the first ended -- with the kids having escaped their evil parents and setting out their own. Now they'll have to figure out how to live with each other, how to take care of themselves and how to take down their parents' nefarious Pride squad. But they might also have to sniff out a mole since someone's sending the Pride leader mysterious messages. Meanwhile, the Pride is bound and determined to get them back.", "source": "C4", "subset": "chatbot" }, { "input": "I need to serve at a dinner party, but I'm not sure where to start. Can you give me 10 actionable steps?", "output": "Which way should you pass the plates? How should you clear the table? Serving at a dinner party isn't exactly as easy as pie. Here are some basic guidelines to help you out in your next dinner party.\n1. Steamed or roasted vegetables, rice, salad, potatoes, and condiments are some of the things you can usually leave in the middle of the table for guests to help themselves to. This isn't strictly necessary, though. If you prefer to keep things simple, you don't have to leave anything more than salt and pepper on the table.\n2. This is known as \"plating-up.\" Don't expect guests to help themselves to food that requires assembly or is difficult to manage. 
The only way they'll get to see your creativity is if you put it out on the plate for them before you even get to the dinner table.\nKeep the rim of the plate clear of sauces, spills, anything; it frames the food.\n3. You can do this if you want to keep with tradition and the occasion is quite formal. Otherwise, choose one end of the table and move around it in a clockwise fashion, regardless of the genders of your guests.\n\nPass all dishes from the left. Guests and servers should pass dishes from left. The logic behind this is that most people are right-handed and this allows them them to serve themselves from the dish while it is being held by the passer. Nowadays it is less likely that the passer will continue to hold the dish, but will expect you to take it, so if you are left-handed, it won't be a problem. Place the dish down on your side plate to serve from it.\n\nAs the cook, or host, always serve yourself last. This is polite and also sensible, since you'll probably be busy anyway with host's duties.\n4. They'll get fidgety, anxious and gossipy about what you're doing.\n5. On the other hand, do not ever go into details about how the flesh portion of the meal was hunted/killed. This is bad taste and makes some guests very queasy. Leave it for discussion around the fireplace with a like-minded friend after dinner.\n6. The host or hired help should clear no more than two plates at a time to avoid bumping guests and interfering with their eating. There is nothing more annoying than the server's elbows in your face when you're just about to take the next bite.\n7. Preferably the noises should not reach the guests but this is unrealistic for most homes. Just do it as quietly as possible and try not to clank, crack, break or drop the dishes. The last thing you need on top of anything else is a dropped plate to clean up.\n8. This means all the dishes on the table, the condiments and the side plates. 
If you haven't already set out the dessert spoons, this is the time to do so.\n9. Chocolates will have their own method of getting around the table; it has been suspected they have legs...\n10. Don't take advantage of the situation because that guest is there to enjoy himself too but don't hesitate to ask for a quick hand with a simple task that won't risk spills on their clothes.\n", "source": "WikiHow", "subset": "main" }, { "input": "What can you tell me about Brushcreek Township? Respond in 2 sentences.", "output": "Brushcreek Township is one of the seventeen townships of Highland County, Ohio, United States. As of the 2010 census the population was 1,381, of whom 1,248 lived in the unincorporated portion of the township.", "source": "Wikipedia", "subset": "chatbot" } ] ================================================ FILE: tests/data/_fixtures/longform_val.json ================================================ [ { "input": "The Big Mistake\n\nThis day was full of joy and happiness, but something went wrong after when she told him that she is pregnant. \nShe was so happy about this news and she thought that he will be happy as well, but... At the moment when she told him that she is pregnant, he just frozen. It was too much for him.\nHe thought it's easier to walk away and avoid this conversation with her. He tried to understand what's going on with him. It was weird how many questions and doubts came up on his head. One part of him was really happy, because to be a father is his one of the biggest dream, but on the other hand he couldn't handle this news. \nThis woman, who stands next to him, is his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loves him as much as he loves her, but everything was just too wrong on this moment. 
Her voice, her smile and even her body was to wrong on this moment.\nShe couldn't understand what\u2019s going on, because it was just three weeks ago when they had a conversation about kids and family and he said that he would like to have kids with her. And now, he just avoids her and he is not happy about the news. \nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they love each other and they care about each other, but it was too much for him. He stood up and said to her girlfriend that he needs a break and he have to be alone. Michael closed the door and knew at that moment he had made a mistake.\n\nFix all grammar errors and style issues and rewrite this document", "output": "The Big Mistake\n\nThis day was full of joy and happiness, but something went wrong when she told him that she was pregnant. \nShe was so happy about this news and she thought that he would be happy as well, but... At the moment when she told him that she was pregnant, he just frozen. It was too much for him.\nHe thought it was easier to walk away and avoid this conversation with her. He tried to understand what was going on with him. It was weird how many questions and doubts came into his head. One part of him was really happy, because to be a father was one of his biggest dreams, but on the other hand, he couldn't handle this news. \nThis woman, who was standing next to him, was his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loved him as much as he loved her, but everything was just all wrong at this moment. 
Her voice, her smile and even her body was wrong on this moment.\nShe couldn't understand what was going on, because it was just three weeks ago that they had a conversation about kids and family and he said that he would like to have kids with her. And now, he was just avoiding her and he was not happy about the news. \nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they loved each other and they cared about each other, but it was too much for him. He stood up and said to his girlfriend that he needed a break and he had to be alone. Michael closed the door and knew at that moment he had made a mistake.\n", "source": "BEA-GEC", "subset": "main" }, { "input": "Although the fact that personal vehicle like car now is popular, in the future, the stance of public transport is still very crucial. \n\nNowaday, people use fossil fuel as the most important source of energy for travelling, manufacturing, this source of energy doesn't always maintain at the constant level, but someday it will be run out. People must find ways to use other energy sources like nuclear electricity, wind power, public transportation such as electric metro use this kind of energy. As a result, governments will invest more in researching the usage of these new types of energy, travelling using public transport will become much cheaper than personal car. The range of public services will be expanded to remote areas, it become much more convenient. Commuting, travelling using public services will be the first prior choice.\n\nIn addition, the population is increasing rapidly, in the near future, almost cities in the world will become overcrowded, It will lead to many other problems to be solved, for example, traffic jams, air pollution. In some perespectives, these problems are led by the over-using of personal vehicle. 
People will realise that using their own car is just to waste their time for traffic jams or to gradually killing themselves for breathing polluted air. So the best way for them to travel is public transpotations. \nEach person should practice saving when using any source of eneny to protect his own live. \nIn conclusion, investments on developing public transport will be increased considerably, public services of travelling have a bright potential and it's stance in the future can't be replaceble\n\nTweak this document to make it read smoother and fix any mistakes you see.", "output": "Despite the fact that personal vehicles like the car are now popular, in the future, the existence of public transport will still absolutelyery crucial. \n\nNowadays, people use fossil fuels as the most important source of energy for travelling and manufacturing. This source of energy doesn't always remain at a constant level, and someday it will run out. People must find ways to use other energy sources like nuclear power and wind power. Public transportation such as the electric metro use this kind of energy. As a result, governments will invest more in researching the use of these new types of energy, and travelling using public transport will become much cheaper than privatepersonal car. The range of public services will be expanded to remote areas, and it will become much more convenient. Commuting and travelling using public services will be the first priority.\n\nIn addition, the population is increasing rapidly. In the near future, most cities in the world will become overcrowded. This will lead to many other problems that will need to be solved. For example, traffic jams and air pollution. In some ways, these problems are caused by the overuse of personal vehicles. People will realise that using their own car is just wasting their time in traffic jams or gradually killing themselves by breathing polluted air. So the best way for them to travel is public transport. 
\nEach person should practice saving energy when using any source of energy to protect his own life. \nIn conclusion, investments in developing public transport will be increased considerably. Public transport services have a bright future and their existence in the future can't be replaced.\n", "source": "BEA-GEC", "subset": "main" }, { "input": "What can you tell me about Nang Ka Teih? Respond in 58 words.", "output": "Founded by the people of Chin (Zo), Nang Ka Teih is a town ward within Sagain Division or next to Khampat town in the Myanmar state of Chin and Sagain Division. It has a very high concentration of the religion. Nang Ka Teih is one of the best location in Chin State and listed among the peace area.", "source": "Wikipedia", "subset": "chatbot" }, { "input": "I'm having a serious argument with a friend on the status of the word 'suck' when I used it about him by saying 'You suck!' because he missed a train. We are both non-native English speakers. He claimed that I used it as a foul or slang term. I vehemently disagreed with him about it. I told him that the word 'suck' can be used to describe something inefficient or not good enough as well like it is used in the example 'Samsung mobile sucks'. Am I right to describe the use of the word 'suck' in such context?.", "output": "Yes, you can use the word in the way that you have described, but it's considered more harsh than polite, and it has somewhat vulgar overtones. How it's regarded or received might be generational. I typed is suck vulgar? on Google, and found mixed responses. Feel free to do the same if you want diverse opinions on the matter. I thought this excerpt from a blog post, though, was worth pasting into an answer here: Some may not believe this, but suck as in 'Man, this class sucks' was also in the raw obscenity category when I was a teenager. It was used plenty in the school hallways but not in front of your teacher and never in front of your mother. 
I remember some agitation by certain culturally-advanced youngsters who tried to railroad their elders into accepting sucks as a safe and harmless substitute for stinks. The elders weren't having any of it, last I checked, but the liberalizing linguists seem to have carried the day. I have always assumedrightly or wrongly, I do not know that the word was originally intended to carry sexual overtones, which was the reason for its suppression. Today, the sexual overtones are either forgotten or are now acceptable in mixed company. I'm not sure which explanation disturbs me more. I think you and your friend are unlikely to come up with an agreed-upon viewpoint, because you're both right in a way. Feel free to use it on message boards and the like when you want to express a negative opinion, but realize you'll risk sounding a bit uncouth to some when you do. Then again, maybe I'm just showing my age here. As a footnote, you might want to check out our sister site, English Language Learners.", "source": "Natural Instructions", "subset": "task225 english language answer generation" }, { "input": "How is BeeHighve Inc. in Corner Brook, NL infusing cannabis with honey products and bringing them to markets, both local and global?", "output": "BeeHighve CEO Rita Hall intends to bring Newfoundland honey and honey-based products to the market, some infused with cannabis, while others will be \"buzz free.\"\nA selection of some of the products available from Corner Brook, N.L.-based BeeHighve Inc.\nThere's a lot of buzz around a joint venture between two Newfoundland companies that want to bring cannabis-infused honey and honey products to local and global markets.\nBeeHighve Inc., based in Corner Brook, will be cultivating the cannabis crop, while G and M Family Farms, near Placentia, will supply the honey.\n\"It ranges from pure honey to sauces as well as chocolate and health bars, and everything is organic,\" said BeeHighve CEO Rita Hall. 
\"And everything is very healthy for you. We don't use sugar in our products.\"\nThe partnership is the brainchild of Hall, a trailblazer who is on track to become the first Indigenous woman to gain licensed producer [of marijuana] status in Canada, as well as one of the country's few Indigenous female CEOs.\nHall intends for Newfoundland honey to be the backbone of her operation, and utilizes it in all the products, including the flagship Nuts About Honey bars.\nAlthough no official date has been set for the legalization of cannabis-infused edibles \u2014 with recreational marijuana use legal as of Oct. 17, 2018 \u2014 BeeHighve plans to produce the same line of products, without the cannabis elements before and after the legalization of edibles.\n\"I don't think it's going to have a negative impact on the business at all. The honey is really generating a lot of interest,\" Hall said.\nPart of the interest is because of Newfoundland's uniquely thriving honey-bee population, who enjoy the benefits of a closed ecosystem comparatively free from mites and diseases associated with colony collapse.\n\"They love the idea of mite-free, antibiotic-free honey. So I have no doubt that the production and sale of honey and the consumables without cannabis infusions will go very well.\"\nEventually BeeHighve intends to get into the beekeeping business as well, allowing them to produce larger volumes of product in less time.\nAside from the plans to export the cannabis-infused products \u2014 where legal \u2014 as well as the \"buzz-free\" ones globally, BeeHighve is looking to expand its production to another province.\nPartnering with the Madawaska Maliseet First Nation reserve in New Brunswick, crops will be cultivated on the reserve to later be infused with Newfoundland honey. 
Hall believes the partnership will be a fruitful one, and has great respect for the Madawaska Maliseet, whose senior leadership is entirely made of women.\n\"It shows the strength of women in any marketplace. Women are really underrepresented in the cannabis industry right now, so it's, I'll say, a feather in our caps, no pun intended, to be a part of this industry as well.\"\nHall has just as much faith in this venture as the one in Newfoundland.\n\"We're very strong women and I think very successful, and we'll succeed at this as well.\"", "source": "C4", "subset": "search-engine" }, { "input": "Respond briefly. What is the history of the Eagle Ranger Station?", "output": "The Eagle Ranger Station, also known as the Eagle Guard Station and presently known as the Sol Duc Ranger Station, is a complex of three buildings built in the 1930s in what would become Olympic National Park. The primary structures were built by the U.S. Forest Service in what was at the time the Olympic National Forest., While the main residence was built by the USFS, the generating plant and landscaping were built by the National Park Service using labor provided by the Civilian Conservation Corps.", "source": "Wikipedia", "subset": "chatbot" }, { "input": "How to deal with non negotiable change\n\nOne of the issues my organisation has is where change requests are requested during a sprint cycle in a particular high volume worksteam. The development team are often forced to accept these change requests, where any resistance results in an escalation, which results with somebody high up in the business just telling the dev team to do it. Thus, team is constantly over-committing every sprint cycle.\nI have highlighted this as an issue to Senior management, showing the impact the additional work is having on morale and quality, but nothing seems to be changing at org level, and they seem to think the problem is at team level. 
Hence, keep trying to address the issue by adding and removing Scrum Masters; where they all are having the same problem.\nWhat is the best way to address this?\n\nA:\n\nTL;DR\nYou have both a political problem and a process problem. There is no silver bullet, and you can't fix these problems solely within the team. Instead, you must rely on the agile concepts of transparency and visibility to ensure that the costs associated with bypassing the process are charged to the project budget.\nMake Costs Visible\nIn agile frameworks, change is never non-negotiable. If you are faced with \"non-negotiable\" requirements that make the goal or plan for the current iteration obsolete, then the team must scrap the plan.\nThis is true regardless of the agile framework you're using. In Scrum, the Product Owner cancels the Sprint and the Scrum Team returns to Sprint Planning. In Kanban, work is pulled off the board to meet work-in-progress (WIP) limits, and queued in the appropriate bucket until capacity is once again available. Your specific framework may vary in the way it implements this freeing up of capacity, but it's an essential component of any viable agile (and even non-agile) process.\nThe cost of replanning and re-queuing is made visible through metrics such as cycle time, cumulative flow, velocity, burn-down, or other framework- and implementation-specific metrics. Communicating about these costs is also essential. By presenting the data to stakeholders and company officers, you place the responsibility for those costs squarely where it belongs.\nSchedule risk (like any risk) can be accepted, transferred, or mitigated. Which of these three options is best is a business decision, not a team decision. 
Regardless of the choice, senior management owns both the choice and the outcome.\nSee Also\nhttps://pm.stackexchange.com/a/10829/4271", "output": "I agree with Todd about it being a political and process problem.\nI want to add that apart from showing them the impact on morale and quality -which I believe that they are quite important for your team- I would recommend having a discussion with them to understand why they believe it is useful to have this way of working and also prepare for them a \"business case\" that explores what is they key impact of this way of working for the senior management (trying to talk their own language basically). \nMost of the times, this comes down to how much money they are losing by having this way of working. Examples of the impact to senior management are:\n\nDelayed features due to the introduction of these changes \nBugs introduced due to quality issues that required to introduce more tasks in future releases.", "source": "StackExchange", "subset": "pm" }, { "input": "Use of LaTeX commands\n\nIt appears that LaTeX commands cannot be used on SE Linguistics, but they can on other SE sites. Personally, I wanted to use this feature a number of times to insert a formula and here's another question with the same problem. Expressing formulae can be useful when discussing things as simple as the number of combinations that arise from the different conditions in an experiment, or when discussing statistics questions relevant to linguistics. \nCould we please allow the use of LaTeX commands? 
I don't think having this feature will be a problem for anyone not wanting to use it and it will help those who do.\n\nA:\n\nAbsolutely subscribing here!\nI am surprised that this is not a feature already, I am really missing TeX support - for a wide range of uses that I consider essential in the field of linguistics: \n\nMathematical formulas.\nAs someone who frequently answers questions on formal semantics, like here, here and here, I feel that answering (and asking) these questions in the current state is a mess, to an extent where it has sometimes kept me from writing answers to those questions altogether, simply because the process of setting them up is so annoying.\nWhile I am aware that there are tools which ease the inclusion of Unicode symbols as HTML, all of these solutions (at least the ones I am aware of) still make typing more than one line of math quite cumbersome given that every symbol has to be clicked on individually, but worse than that, many characters and necessary formatting options, like pretty much half of the inventory I needed in this linguistically originated question, aren't even available as predefined Unicode symbols at all. Not to mention the fact that the output is visually not particularly pleasing and often makes the question hard to read, as can well be seen in this post.\nBesides, I doubt that the average user is aware of the existence of such tools; it took me more than a year on SE to find out about these, while new users have repeatedly been asking for LaTeX support, because this what users are used to.\nApart from formulas for questions on logic, MathJax would prove useful in questions involving statistics or combinatorics, which time and time again turn up this site. 
\nIPA, as already discussed here and here, to frequently be used in posts like here, here, here and many posts by other users.\nAt this point I am not sure whether enabling MathJax will also provide the opportunity to support other LaTeX packages, but if so, being able to use tipa would be of great help in typing IPA symbols.\nI am aware of IPA typeit which I have used so far, and some other external tools, but again, clicking is cumbersome, and users will be more comfortable writing questions and answers with the tools they are used to (and the standard way of typesetting IPA symbols is LaTeX tipa). \nOther, less urgent cases where I could imagine TeX support to be useful - again under the premise that support for special packages is in principle possible - include\n\ntree structures, most importantly syntax trees, like here, here and here; a similar request was already made on meta here.\nlinguistic example formatting, which I could have used here, and interlinear glosses for posts like here, which are rather inconvenient to do appropriately without automatic alignment.\nThe current solution to these use cases is to typeset them externally and then include them in the post via a screenshot - I could live with this solution for these particular cases, but if MathJax comes, then this would certainly be a nice feature for these less frequent use cases as well. \n\nA:\n\nIt won't be a problem, but we need to show that the site actually needs it. \nOn Physics SE, the same question has been asked and you can see it got at least 32 votes. If there is not enough support from the community, which means enough people that want it, then it's probably not gonna happen. It's the same as for Migration paths in other sites, you need to prove there is a case for it to be implemented.\nIf it gets enough support, it'll be added. ", "output": "A clear yes. 
And do not forget to include tipa support for IPA in LaTeX, I find tipa input from the keyboard very convenient (compared to searching, cutting, and pasting IPA characters from a Unicode character table).", "source": "StackExchange", "subset": "linguistics" }, { "input": "In this task, you are given text for US Congressional and California state bills, your task is to generate a summary for this bill.\n\"SECTION 1. SHORT TITLE.\n\n This Act may be cited as the ``Sequestration Relief Act of 2013''.\n\nSEC. 2. FINDINGS AND PURPOSE.\n\n (a) Findings.--Congress finds the following:\n (1) Congress must enact a comprehensive, deficit reduction \n plan to solve the country's fiscal challenges and to promote \n national security, economic stability, and the continued growth \n and prosperity of the United States.\n (2) The keys to a comprehensive, deficit reduction solution \n are increased revenues and changes in mandatory spending.\n (3) The Budget Control Act of 2011 was enacted to avert a \n default on Federal debt obligations, and it reduced \n discretionary spending by approximately $1 trillion through \n fiscal year 2021.\n (4) Because the Joint Select Committee on Deficit Reduction \n failed to recommend legislation providing an additional $1.2 \n trillion in deficit reduction, Federal law mandates that the \n additional savings be sequestered.\n (5) Sequestration was designed as a forcing mechanism for \n an agreement on a comprehensive, deficit reduction plan. It has \n failed to produce the intended results.\n (6) It no longer makes sense to rely on sequestration as a \n forcing mechanism for a balanced solution. 
The costs to our \n government and to the economy are too great.\n (7) Under sequestration, automatic, indiscriminate cuts \n would be applied, through fiscal year 2021, to a wide variety \n of discretionary spending programs to achieve $1.2 trillion in \n savings, forestalling the sound planning needed for prudent and \n meaningful investments in national security, the workforce, \n transportation infrastructure, education, health care, public \n safety, housing, innovation, small business development, and \n many other facets of enduring national strength.\n (8) Even the prospect of sequestration is disruptive to \n regular order and to the congressional appropriations process, \n and it fosters damaging economic uncertainty, while short-term \n solutions only suspend the prospect and continue to undermine \n the certainty needed for economic recovery.\n (9) Therefore, Congress must eliminate the threat of \n sequestration.\n (10) Given the magnitude of the Federal deficit, it is \n likely that additional cuts to discretionary spending will be \n necessary for a comprehensive deficit reduction solution.\n (11) Congress must establish a manageable, long-term \n discretionary spending plan. An additional $320 billion in \n targetable cuts to discretionary appropriations from fiscal \n year 2014 through fiscal year 2021 represents one-third of the \n net amount that would have been indiscriminately cut by \n sequestration over fiscal years 2013 through 2021.\n (12) It is recognized that a reduction of $167 billion to \n discretionary appropriations within budget function 050 from \n fiscal year 2014 through fiscal year 2021 will affect the \n National Military Strategy. 
The Department of Defense is highly \n encouraged to revisit its current strategic guidance and to \n work closely with Congress in building a new National Military \n Strategy that accounts for available resource levels.\n (b) Purposes.--The purposes of this Act are to--\n (1) eliminate the threat of sequestration to the American \n economy;\n (2) offer the Federal Government, industry, and the \n American people the predictability that economic recovery \n demands;\n (3) enable the Congress to pass appropriations legislation \n in regular order with a clear discretionary spending budget and \n grant the legislative and executive branches of government the \n flexibility needed to identify and implement specific \n discretionary spending reductions in a responsible and \n deliberate manner; and\n (4) provide a practicable, long-term discretionary spending \n plan that will contribute to a comprehensive, balanced, long-\n term, deficit reduction solution that includes affordable \n revisions to mandatory spending and new revenues.\n\nSEC. 3. REPEAL OF SECTION 251A SEQUESTRATIONS.\n\n Section 251A of the Balanced Budget and Emergency Deficit Control \nAct of 1985 is repealed.\n\nSEC. 4. 
$320 BILLION REDUCTION IN DISCRETIONARY SPENDING LIMITS.\n\n The discretionary spending limits set forth in paragraphs (3) \nthrough (10) of section 251(c) of the Balanced Budget and Emergency \nDeficit Control Act of 1985 are amended to read as follows:\n ``(3) for fiscal year 2014--\n ``(A) for the security category, $546,000,000,000 \n in budget authority; and\n ``(B) for the nonsecurity category, \n $501,000,000,000 in budget authority;\n ``(4) with respect to fiscal year 2015--\n ``(A) for the security category, $550,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $505,000,000,000 in new budget authority;\n ``(5) with respect to fiscal year 2016--\n ``(A) for the security category, $559,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $513,000,000,000 in new budget authority;\n ``(6) with respect to fiscal year 2017--\n ``(A) for the security category, $569,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $522,000,000,000 in new budget authority;\n ``(7) with respect to fiscal year 2018--\n ``(A) for the security category, $579,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $531,000,000,000 in new budget authority;\n ``(8) with respect to fiscal year 2019--\n ``(A) for the security category, $589,500,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $541,000,000,000 in new budget authority;\n ``(9) with respect to fiscal year 2020--\n ``(A) for the security category, $602,500,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $553,000,000,000 in new budget authority;\n ``(10) with respect to fiscal year 2021--\n ``(A) for the security category, $616,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $565,000,000,000 in new budget authority;''.\n\nSEC. 5. 
DEFINITION OF SECURITY CATEGORY.\n\n Section 250(c)(4)(B) of the Balanced Budget and Emergency Deficit \nControl Act of 1985 is amended to read as follows:\n ``(B)(i) For fiscal years 2012 and 2013, the term `security \n category' means discretionary appropriations associated with \n agency budgets for the Department of Defense, the Department of \n Homeland Security, the Department of Veterans Affairs, the \n National Nuclear Security Administration, the intelligence \n community management account (95-0401-0-1-054), and all budget \n accounts in budget function 150 (international affairs).\n ``(ii) For fiscal years 2014 through 2021, the term \n `security category' means discretionary appropriations in \n budget function 050 (national defense).''.\n\nSEC. 6. SUSPENSION OF STATUTORY LIMIT ON THE PUBLIC DEBT UNTIL FEBRUARY \n 1, 2017.\n\n Section 2 of the No Budget, No Pay Act of 2013 is amended--\n (1) in subsection (a), by striking ``May 18, 2013'' and \n inserting ``January 31, 2017''; and\n (2) in subsection (b), by striking ``May 19, 2013'' each \n place it appears and inserting ``February 1, 2017''.\"", "output": "Sequestration Relief Act of 2013 - Amends the Balanced Budget and Emergency Deficit Control Act of 1985 (Gramm-Rudman-Hollings Act), as amended by the Budget Control Act of 2011, to repeal automatic sequestration to enforce a specified budget goal. Reduces the discretionary spending limits for security and nonsecurity categories by $320 billion for FY2014-FY2021. Confines the current meaning of \"security category\" under the Gramm-Rudman-Hollings Act to FY2012-FY2013. Revises the definition of \"security category\" for FY2014-FY2021 to mean discretionary appropriations in all of budget function 050 (national defense). Amends the No Budget, No Pay Act of 2013 to suspend through January 31, 2017, the current $16.394 trillion public debt limit. 
Postpones until February 1, 2017, an automatic increase in the public debt limit to the extent that: (1) the face amount of obligations issued and the face amount of obligations whose principal and interest are guaranteed by the federal government (except guaranteed obligations held by the Secretary of the Treasury) outstanding on February 1, 2017, exceeds (2) the face amount of such obligations outstanding on the date of enactment of the No Budget, No Pay Act of 2013 (February 24, 2013). Prohibits an obligation from being taken into account unless its issuance was necessary to fund a commitment incurred by the federal government that required payment before February 1, 2017.", "source": "Natural Instructions", "subset": "task1658 billsum summarization" } ] ================================================ FILE: tests/data/test_alpaca.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from litgpt.data import Alpaca from litgpt.prompts import Alpaca as AlpacaPromptStyle def test_alpaca(mock_tokenizer, alpaca_path): alpaca = Alpaca(val_split_fraction=0.5, download_dir=alpaca_path.parent, file_name=alpaca_path.name, num_workers=0) assert isinstance(alpaca.prompt_style, AlpacaPromptStyle) alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10) alpaca.prepare_data() alpaca.setup() train_dataloader = alpaca.train_dataloader() val_dataloader = alpaca.val_dataloader() assert len(train_dataloader) == 6 assert len(val_dataloader) == 6 train_batch = next(iter(train_dataloader)) val_batch = next(iter(val_dataloader)) assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"} for key in ["input_ids", "labels"]: assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]" assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]" assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle) assert 
isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle) # has attributes from super class `LightningDataModule` assert alpaca.prepare_data_per_node ================================================ FILE: tests/data/test_base.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from typing import Optional import pytest import torch from litgpt.data.base import SFTDataset, get_sft_collate_fn from litgpt.prompts import PromptStyle @pytest.mark.parametrize("mask_prompt", [True, False]) @pytest.mark.parametrize("ignore_index", [-1, -100]) @pytest.mark.parametrize("max_seq_length", [1000, 5, -1]) def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer): class Style(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str: return f"In: {prompt} Out:" i = ignore_index data = [{"instruction": "Foo", "output": "Bar"}, {"instruction": "Boo", "output": "Ahh"}] dataset = SFTDataset( data=data, tokenizer=mock_tokenizer, prompt_style=Style(), mask_prompt=mask_prompt, ignore_index=ignore_index, max_seq_length=max_seq_length, ) assert len(dataset) == len(data) expected_input_ids = torch.tensor([73, 110, 58, 32, 70, 111, 111, 32, 79, 117, 116, 58, 66, 97, 114, 1]) # If prompt is not masked, labels == input_ids expected_labels = ( torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids ) if max_seq_length == -1: assert torch.equal(dataset[0]["input_ids"], expected_input_ids) assert torch.equal(dataset[0]["labels"], expected_labels) else: assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length]) assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length]) @pytest.mark.parametrize("ignore_index", [-1, -100]) @pytest.mark.parametrize("pad_id", [0, 100]) def test_sft_collate_fn_padding(pad_id, ignore_index): collate = 
get_sft_collate_fn(pad_id=pad_id, ignore_index=ignore_index) samples = [ { "input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30]), "token_counts": {"raw": 3, "raw_plus_prompt_template": 25}, }, { "input_ids": torch.tensor([4, 5, 6, 7, 8]), "labels": torch.tensor([40, 50, 60, 70, 80]), "token_counts": {"raw": 5, "raw_plus_prompt_template": 27}, }, ] expected = { "input_ids": torch.tensor([[1, 2, 3, pad_id, pad_id], [4, 5, 6, 7, 8]]), "labels": torch.tensor([[10, 20, 30, ignore_index, ignore_index], [40, 50, 60, 70, 80]]), "token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])}, } batch = collate(samples) assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels")) for key in ("raw", "raw_plus_prompt_template"): assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}" def test_sft_collate_fn_truncation(): collate = get_sft_collate_fn(max_seq_length=2) samples = [ { "input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30]), "token_counts": {"raw": 3, "raw_plus_prompt_template": 25}, }, { "input_ids": torch.tensor([4, 5, 6, 7, 8]), "labels": torch.tensor([40, 50, 60, 70, 80]), "token_counts": {"raw": 5, "raw_plus_prompt_template": 27}, }, ] expected = { "input_ids": torch.tensor([[1, 2], [4, 5]]), "labels": torch.tensor([[10, 20], [40, 50]]), "token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])}, } batch = collate(samples) assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels")) for key in ("raw", "raw_plus_prompt_template"): assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}" ================================================ FILE: tests/data/test_deita.py ================================================ # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. from unittest import mock from litgpt.data import Deita, SFTDataset from litgpt.data.deita import format_dataset from litgpt.prompts import Alpaca as AlpacaPromptStyle def test_format_dataset(): data = [ { "prompt": "prompt1", "prompt_id": "1", "messages": [ {"content": "question1", "role": "user"}, {"content": "response1", "role": "assistant"}, {"content": "question2", "role": "user"}, {"content": "response2", "role": "assistant"}, ], }, { "prompt": "prompt2", "prompt_id": "2", "messages": [ {"content": "question3", "role": "user"}, {"content": "response3", "role": "assistant"}, {"content": "question4", "role": "user"}, {"content": "response4", "role": "assistant"}, ], }, ] assert format_dataset(data, include_multi_turn_conversations=False) == [ {"instruction": "question1", "output": "response1", "input": ""}, {"instruction": "question3", "output": "response3", "input": ""}, ] assert format_dataset(data, include_multi_turn_conversations=True) == [ {"instruction": "question1", "output": "response1", "input": ""}, {"instruction": "question2", "output": "response2", "input": ""}, {"instruction": "question3", "output": "response3", "input": ""}, {"instruction": "question4", "output": "response4", "input": ""}, ] @mock.patch("litgpt.data.deita.format_dataset") @mock.patch("datasets.load_dataset") def test_deita(_, format_dataset_mock, mock_tokenizer, tmp_path): format_dataset_mock.return_value = [ {"instruction": "inst1", "output": "out1"}, {"instruction": "inst2", "output": "out2"}, {"instruction": "inst3", "output": "out3"}, ] deita = Deita(num_workers=0, download_dir=tmp_path) assert isinstance(deita.prompt_style, AlpacaPromptStyle) deita.connect(mock_tokenizer, batch_size=2, max_seq_length=10) deita.prepare_data() deita.setup() train_dataloader = deita.train_dataloader() assert isinstance(train_dataloader.dataset, SFTDataset) assert len(train_dataloader) == 2 val_dataloader = deita.val_dataloader() assert 
isinstance(val_dataloader.dataset, SFTDataset) assert len(val_dataloader) == 2 assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle) assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle) # has attributes from super class `LightningDataModule` assert deita.prepare_data_per_node ================================================ FILE: tests/data/test_json.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import json from typing import Optional import pytest from litgpt.data import JSON from litgpt.prompts import PromptStyle @pytest.mark.parametrize("as_jsonl", [False, True]) def test_json(as_jsonl, tmp_path, mock_tokenizer): class Style(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str: return f"X: {prompt} {kwargs['input']} Y:" json_path = tmp_path / ("data.jsonl" if as_jsonl else "data.json") mock_data = [ {"instruction": "Add", "input": "2+2", "output": "4"}, {"instruction": "Subtract", "input": "5-3", "output": "2"}, {"instruction": "Multiply", "input": "6*4", "output": "24"}, {"instruction": "Divide", "input": "10/2", "output": "5"}, {"instruction": "Exponentiate", "input": "2^3", "output": "8"}, {"instruction": "Square root", "input": "√9", "output": "3"}, ] with open(json_path, "w", encoding="utf-8") as fp: if as_jsonl: for line in mock_data: json.dump(line, fp) fp.write("\n") else: json.dump(mock_data, fp) data = JSON(json_path, val_split_fraction=0.5, prompt_style=Style(), num_workers=0) data.connect(tokenizer=mock_tokenizer, batch_size=2) data.prepare_data() # does nothing data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() assert len(train_dataloader) == 2 assert len(val_dataloader) == 2 train_data = list(train_dataloader) val_data = list(val_dataloader) assert train_data[0]["input_ids"].size(0) == 2 assert train_data[1]["input_ids"].size(0) == 
1 assert val_data[0]["input_ids"].size(0) == 2 assert val_data[1]["input_ids"].size(0) == 1 assert mock_tokenizer.decode(train_data[0]["input_ids"][0]).startswith("X: Divide 10/2 Y:5") assert mock_tokenizer.decode(train_data[0]["input_ids"][1]).startswith("X: Add 2+2 Y:4") assert mock_tokenizer.decode(train_data[1]["input_ids"][0]).startswith("X: Multiply 6*4 Y:24") assert mock_tokenizer.decode(val_data[0]["input_ids"][0]).startswith("X: Exponentiate 2^3 Y:8") assert mock_tokenizer.decode(val_data[0]["input_ids"][1]).startswith("X: Subtract 5-3 Y:2") assert mock_tokenizer.decode(val_data[1]["input_ids"][0]).startswith("X: Square root √9 Y:3") assert isinstance(train_dataloader.dataset.prompt_style, Style) assert isinstance(val_dataloader.dataset.prompt_style, Style) # has attributes from super class `LightningDataModule` assert data.prepare_data_per_node def test_json_input_validation(tmp_path): with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"): JSON(tmp_path / "not exist") with pytest.raises(ValueError, match="`val_split_fraction` should not be set"): JSON(tmp_path, val_split_fraction=0.5) data = JSON(tmp_path) data.prepare_data() # does nothing # Empty directory with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"): data.setup() # Only train.json exists (tmp_path / "train.json").touch() with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"): data.setup() # When a single file is passed without val_split_fraction, it defaults to 0.05 and warns. 
with pytest.warns(UserWarning, match="Defaulting to `val_split_fraction=0.05`"): data = JSON(tmp_path / "train.json", val_split_fraction=None) assert data.val_split_fraction == 0.05 @pytest.mark.parametrize("as_jsonl", [False, True]) def test_json_with_splits(as_jsonl, tmp_path, mock_tokenizer): mock_train_data = [ {"instruction": "Add", "input": "2+2", "output": "4"}, {"instruction": "Subtract", "input": "5-3", "output": "2"}, {"instruction": "Exponentiate", "input": "2^3", "output": "8"}, ] mock_test_data = [ {"instruction": "Multiply", "input": "6*4", "output": "24"}, {"instruction": "Divide", "input": "10/2", "output": "5"}, ] train_file = tmp_path / ("train.jsonl" if as_jsonl else "train.json") val_file = tmp_path / ("val.jsonl" if as_jsonl else "val.json") with open(train_file, "w", encoding="utf-8") as fp: if as_jsonl: for line in mock_train_data: json.dump(line, fp) fp.write("\n") else: json.dump(mock_train_data, fp) with open(val_file, "w", encoding="utf-8") as fp: if as_jsonl: for line in mock_test_data: json.dump(line, fp) fp.write("\n") else: json.dump(mock_test_data, fp) data = JSON(tmp_path, num_workers=0) data.connect(tokenizer=mock_tokenizer, batch_size=2) data.prepare_data() # does nothing data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() assert len(train_dataloader) == 2 assert len(val_dataloader) == 1 ================================================ FILE: tests/data/test_lit_data.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import sys from unittest import mock from unittest.mock import ANY import pytest from litgpt.data import LitData @pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining") @mock.patch("litgpt.data.lit_data.LitData._dataloader") def test_input_dir_and_splits(dl_mock, tmp_path): with pytest.raises(ValueError, match="If provided `split_names` must be a tuple of two strings"): LitData(data_path=tmp_path, split_names=("train",)) # local dir, no splits data = LitData(data_path=tmp_path) data.train_dataloader() dl_mock.assert_called_with(input_dir=str(tmp_path), train=True) data.val_dataloader() dl_mock.assert_called_with(input_dir=str(tmp_path), train=False) # local dir, splits data = LitData(data_path=tmp_path, split_names=("train", "val")) data.train_dataloader() dl_mock.assert_called_with(input_dir=str(tmp_path / "train"), train=True) data.val_dataloader() dl_mock.assert_called_with(input_dir=str(tmp_path / "val"), train=False) # remote dir, splits data = LitData(data_path="s3://mydataset/data", split_names=("train", "val")) data.train_dataloader() dl_mock.assert_called_with(input_dir="s3://mydataset/data/train", train=True) data.val_dataloader() dl_mock.assert_called_with(input_dir="s3://mydataset/data/val", train=False) @pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining") @mock.patch("litdata.streaming.StreamingDataset") @mock.patch("litdata.streaming.StreamingDataLoader") def test_dataset_args(streaming_dataloader_mock, streaming_dataset_mock, tmp_path): data = LitData(data_path=tmp_path, seed=1000) data.train_dataloader() streaming_dataset_mock.assert_called_with( input_dir=str(tmp_path), item_loader=ANY, shuffle=True, seed=1000, ) streaming_dataloader_mock.assert_called_with( streaming_dataset_mock(), batch_size=1, pin_memory=True, num_workers=8, drop_last=True, ) ================================================ FILE: tests/data/test_longform.py 
================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from litgpt.data import LongForm from litgpt.prompts import Longform as LongFormPromptStyle def test_longform(mock_tokenizer, longform_path): longform = LongForm(download_dir=longform_path, num_workers=0) assert isinstance(longform.prompt_style, LongFormPromptStyle) longform.connect(mock_tokenizer, batch_size=2, max_seq_length=10) longform.prepare_data() longform.setup() train_dataloader = longform.train_dataloader() val_dataloader = longform.val_dataloader() assert len(train_dataloader) == 9 assert len(val_dataloader) == 5 train_batch = next(iter(train_dataloader)) val_batch = next(iter(val_dataloader)) assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"} for key in ["input_ids", "labels"]: assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]" assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]" assert isinstance(train_dataloader.dataset.prompt_style, LongFormPromptStyle) assert isinstance(val_dataloader.dataset.prompt_style, LongFormPromptStyle) # has attributes from super class `LightningDataModule` assert longform.prepare_data_per_node ================================================ FILE: tests/data/test_openwebtext.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import sys from unittest import mock from unittest.mock import ANY, call import pytest from litdata.streaming import StreamingDataLoader, StreamingDataset from torch.utils.data import DataLoader from litgpt.data import OpenWebText @pytest.mark.skipif(sys.platform == "win32", reason="Not in the mood to add Windows support right now.") @mock.patch("litdata.optimize") @mock.patch("litdata.streaming.dataset.subsample_streaming_dataset", return_value=([], [])) @mock.patch("datasets.load_dataset") def test_openwebtext(_, __, optimize_mock, tmp_path, mock_tokenizer): data = OpenWebText(data_path=(tmp_path / "openwebtext")) assert data.seq_length == 2048 assert data.batch_size == 1 data.connect(tokenizer=mock_tokenizer, batch_size=2, max_seq_length=1024) assert data.seq_length == 1025 assert data.batch_size == 2 # Data does not exist, preprocess it data.prepare_data() optimize_mock.assert_has_calls( [ call( fn=ANY, num_workers=ANY, inputs=[], output_dir=str(tmp_path / "openwebtext" / "train"), chunk_bytes="200MB", ), call( fn=ANY, num_workers=ANY, inputs=[], output_dir=str(tmp_path / "openwebtext" / "val"), chunk_bytes="200MB", ), ] ) optimize_mock.reset_mock() # Data exists, already preprocessed (tmp_path / "openwebtext" / "train").mkdir(parents=True) (tmp_path / "openwebtext" / "val").mkdir(parents=True) data.prepare_data() optimize_mock.assert_not_called() data.setup() train_dataloader = data.train_dataloader() assert isinstance(train_dataloader, StreamingDataLoader) assert isinstance(train_dataloader.dataset, StreamingDataset) val_dataloader = data.val_dataloader() assert isinstance(val_dataloader, DataLoader) assert isinstance(val_dataloader.dataset, StreamingDataset) # has attributes from super class `LightningDataModule` assert data.prepare_data_per_node ================================================ FILE: tests/data/test_textfiles.py ================================================ import json import torch from litdata import TokensLoader, optimize from 
torch.utils._pytree import tree_map from litgpt.data.text_files import TextFiles class Tokenizer: bos_id = 0 def encode(self, text, bos, eos): assert bos assert not eos return [self.bos_id] + [ord(c) for c in text] def tokenize(data): for story in data: yield torch.tensor(story) def fake_chunk(path, data): optimize( fn=tokenize, inputs=[data] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB", item_loader=TokensLoader(), ) def test_textfiles_datamodule(tmp_path): from litgpt.data.text_files import TextFiles data_dir = tmp_path / "textfiles" datamodule = TextFiles(train_data_path=data_dir, num_workers=1) datamodule.connect(max_seq_length=2, tokenizer=Tokenizer()) # simulate `datamodule.prepare_data` train_data_dir = data_dir / "train" train_data_dir.mkdir(parents=True) fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]]) datamodule.setup() tr_dataloader = datamodule.train_dataloader() tr_dataloader.shuffle = False actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) # there is 1 sample per index in the data (13) assert actual == [ [[73, 5, 0]], [[12, 0, 23]], [[5, 0, 1]], [[0, 73, 5]], [[1999, 0, 13]], [[0, 1, 1999]], [[1, 1999, 0]], [[0, 23, 15]], [[13, 12, 0]], [[63, 0, 73]], [[23, 15, 63]], [[15, 63, 0]], [[0, 13, 12]], ] class MockTokenizer: bos_id = 0 eos_id = 1 use_bos = True def encode(self, text, bos=True, eos=False, device=None, max_length=-1): # Simple: map each character to its ordinal + 2 tokens = [ord(c) + 2 for c in text] if bos: tokens = [self.bos_id] + tokens if eos: tokens.append(self.eos_id) if max_length > 0: tokens = tokens[:max_length] return torch.tensor(tokens, dtype=torch.long, device=device) def decode(self, tensor): ids = tensor.tolist() if tensor.ndim > 0 else [tensor.item()] chars = [] for tid in ids: if tid == self.bos_id: chars.append("") elif tid == self.eos_id: chars.append("") else: chars.append(chr(tid - 2)) return "".join(chars) def decode_stream(self, token_stream, 
device=None): for token in token_stream: yield self.decode(token) @property def vocab_size(self): return 130 def test_textfiles_token_loader(tmp_path): # Create the directory for text files data_dir = tmp_path / "textfiles" data_dir.mkdir(parents=True, exist_ok=True) # Write sample training data to the directory sample_texts = ["hello world", "foo bar", "lorem ipsum"] for i, text in enumerate(sample_texts): (data_dir / f"{i}.txt").write_text(text) datamodule = TextFiles(train_data_path=data_dir, num_workers=1) datamodule.connect(max_seq_length=2, tokenizer=MockTokenizer()) datamodule.prepare_data() # ensure training set uses tokens loader index_json = data_dir / "train" / "index.json" assert index_json.exists() meta = json.loads(index_json.read_text()) assert meta["config"]["item_loader"] == "TokensLoader" # ensure validation set uses tokens loader index_json = data_dir / "val" / "index.json" assert index_json.exists() meta = json.loads(index_json.read_text()) assert meta["config"]["item_loader"] == "TokensLoader" ================================================ FILE: tests/data/test_tinyllama.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
from unittest import mock import pytest from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset from torch.utils.data import DataLoader from litgpt.data import TinyLlama @mock.patch("litdata.streaming.dataset.subsample_streaming_dataset", return_value=([], [])) def test_tinyllama(_, tmp_path): data = TinyLlama(data_path=(tmp_path / "data")) assert data.seq_length == 2048 assert data.batch_size == 1 data.connect(batch_size=2, max_seq_length=1024) assert data.seq_length == 1025 assert data.batch_size == 2 with pytest.raises(FileNotFoundError, match="The directory .*data/slimpajama/train does not exist"): data.prepare_data() (tmp_path / "data" / "slimpajama" / "train").mkdir(parents=True) (tmp_path / "data" / "slimpajama" / "val").mkdir(parents=True) (tmp_path / "data" / "starcoder").mkdir(parents=True) data.prepare_data() data.setup() train_dataloader = data.train_dataloader() assert isinstance(train_dataloader, StreamingDataLoader) assert isinstance(train_dataloader.dataset, CombinedStreamingDataset) val_dataloader = data.val_dataloader() assert isinstance(val_dataloader, DataLoader) assert isinstance(val_dataloader.dataset, StreamingDataset) # has attributes from super class `LightningDataModule` assert data.prepare_data_per_node ================================================ FILE: tests/data/test_tinystories.py ================================================ import json import pytest import torch from litdata import optimize from litdata.streaming import StreamingDataset, TokensLoader from torch.utils._pytree import tree_map def tokenize(data): for story in data: yield torch.tensor(story) def fake_chunk(path, data): optimize( fn=tokenize, inputs=[data] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB", item_loader=TokensLoader(), ) @pytest.mark.parametrize( ("max_seq_len", "expected"), [ (2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]), (5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]), (6, 
[[0, 23, 15, 63, 0, 73, 5]]), (7, [[0, 23, 15, 63, 0, 73, 5, 0]]), ], ) def test_pretok_dataset(tmp_path, max_seq_len, expected): fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13] assert len(fake_data) == 12 fake_chunk(tmp_path, [fake_data]) dataset = StreamingDataset( input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False ) actual = tree_map(torch.Tensor.tolist, list(dataset)) assert actual == expected def test_tokenize(tmp_path, monkeypatch): from litgpt.data.tinystories import tokenize story1, story2 = "foo bar", " fun " data = [{"story": story1}, {"story": story2}] shard_path = tmp_path / "data.json" with open(shard_path, "w", encoding="utf-8") as f: json.dump(data, f) class Tokenizer: bos_id = 0 def encode(self, text, bos, eos): assert bos assert not eos return [self.bos_id] + [ord(c) for c in text] monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", "0") monkeypatch.setenv("DATA_OPTIMIZER_NUM_WORKERS", "1") data = tokenize(str(shard_path), Tokenizer()) assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]] def test_tinystories_datamodule(tmp_path): from litgpt.data.tinystories import TinyStories data_dir = tmp_path / "tinystories" datamodule = TinyStories(data_dir, seed=42, num_workers=1) datamodule.connect(max_seq_length=2) # simulate `datamodule.prepare_data` train_data_dir = data_dir / "train" train_data_dir.mkdir(parents=True) fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]]) datamodule.setup() tr_dataloader = datamodule.train_dataloader() tr_dataloader.shuffle = False actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) # there is 1 sample per index in the data (13) assert actual == [ [[73, 5, 0]], [[12, 0, 23]], [[5, 0, 1]], [[0, 73, 5]], [[1999, 0, 13]], [[0, 1, 1999]], [[1, 1999, 0]], [[0, 23, 15]], [[13, 12, 0]], [[63, 0, 73]], [[23, 15, 63]], [[15, 63, 0]], [[0, 13, 12]], ] ================================================ FILE: 
tests/ext_thunder/__init__.py ================================================ import sys from pathlib import Path # support running without installing as a package, adding extensions to the Python path wd = Path(__file__).parent.parent.parent.resolve() if wd.is_dir(): sys.path.append(str(wd)) else: import warnings warnings.warn(f"Could not find extensions directory at {wd}") ================================================ FILE: tests/ext_thunder/test_thunder_distributed.py ================================================ import os import sys from pathlib import Path from typing import Optional, Tuple, Union import pytest import torch from lightning.fabric import Fabric from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from litgpt.constants import _THUNDER_AVAILABLE from litgpt.utils import _RunIf # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) if _THUNDER_AVAILABLE: from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy @_RunIf(thunder=True) def test_thunder_strategy_ddp_input_parsing(): with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"): ThunderDDPStrategy(jit=False, executors=("python",)) @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.parametrize("choice", ["ddp", "fsdp"]) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_no_backward_sync_thunder(choice): if choice == "ddp": strategy = ThunderDDPStrategy() elif choice == "fsdp": strategy = ThunderFSDPStrategy() else: raise ValueError(f"Invalid choice: {choice}") fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) fabric.launch() # account for sharding in the case of FSDP out_features = 1 if "ddp" in choice else fabric.world_size model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device) x = 
torch.randn(1, 1, device=fabric.device) model = fabric.setup(model) # 6 iters, 3 grad accumulation iters for i, enabled in enumerate((True, True, False, True, True, False), 1): x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) with fabric.no_backward_sync(model, enabled): y = model(x) fabric.backward(y.sum()) if not enabled: # Math for the first 3 iters # # DistributedDataParallel # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9 # ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^ # rank0 rank1 allreduce # # thunder.distributed.ddp # ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9 # ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^ # rank0 rank1 allreduce1 rank0 rank1 allreduce2 assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) assert not hasattr(model.weight, "_thunder_fsdp_unsharded_grad") model.weight.grad = None elif choice == "fsdp": assert model.weight._thunder_fsdp_unsharded_grad.shape == (2, 1) assert model.weight.grad is None @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.parametrize("jit", (False, True)) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_jit_ddp_before_setup(jit): import thunder fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit)) fabric.launch() x = torch.randn(1, 1, device=fabric.device) model = torch.nn.Linear(1, 2, bias=False, device=fabric.device) tmodel = thunder.jit(model) fmodel = fabric.setup(tmodel) fmodel(x) assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python() @_RunIf(min_cuda_gpus=1, thunder=True) def test_strategy_ddp_setup_already_traced(): import thunder device = torch.device("cuda") x = torch.randn(1, 1, device=device) model = torch.nn.Linear(1, 2, bias=False, device=device) strategy = ThunderDDPStrategy() tmodel = thunder.jit(model) tmodel(x) with pytest.raises(RuntimeError, match="already called"): strategy.setup_module(tmodel) 
@_RunIf(thunder=True)
def test_thunder_strategy_fsdp_input_parsing():
    """``ThunderFSDPStrategy`` parses string options into thunder enums and rejects ``executors`` without jit."""
    from thunder.distributed import FSDPBucketingStrategy, FSDPType

    # the mixed-case "BlOcK" still resolves to BLOCK, i.e. the string is matched case-insensitively
    strategy = ThunderFSDPStrategy(bucketing_strategy="BlOcK", executors=("python",), sharding_strategy="zero3")
    assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK
    assert strategy.sharding_strategy is FSDPType.ZERO3

    # `executors` only matter when jitting, so combining them with `jit=False` raises
    with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
        ThunderFSDPStrategy(jit=False, executors=("python",))


@_RunIf(thunder=True)
def test_save_checkpoint_invalid_settings_raise(tmp_path):
    """Invalid arguments to ``save_checkpoint`` / ``load_checkpoint`` raise descriptive errors.

    The checks run in sequence and share ``model``; ``model.use_fsdp`` is set part-way through, so order matters.
    """
    strategy = ThunderFSDPStrategy(state_dict_type="full")

    # custom storage options are rejected
    with pytest.raises(TypeError, match="not supported"):
        strategy.save_checkpoint(tmp_path, {}, storage_options=object())

    # `tmp_path` already exists as a directory, so a full checkpoint cannot be written there
    with pytest.raises(IsADirectoryError, match="path exists"):
        strategy.save_checkpoint(tmp_path, {})

    model = torch.nn.Linear(1, 1)
    # an empty state contains no model
    with pytest.raises(ValueError, match="Could not find"):
        strategy.save_checkpoint(tmp_path / "foo", {})

    # NOTE(review): presumably the attribute the strategy uses to recognise an FSDP model in the state -- confirm
    model.use_fsdp = True
    with pytest.raises(ValueError, match="Found multiple"):
        strategy.save_checkpoint(tmp_path / "foo", {"model1": model, "model2": model})

    with pytest.raises(ValueError, match="at least a model"):
        strategy.load_checkpoint(tmp_path / "foo", {})

    # loading directly into a model requires a single checkpoint file, not a directory
    with pytest.raises(ValueError, match="must be a single file"):
        strategy.load_checkpoint(tmp_path, model)

    optimizer = torch.optim.Adam(model.parameters())
    # loading into a bare optimizer is not implemented
    with pytest.raises(NotImplementedError, match="not supported"):
        strategy.load_checkpoint(tmp_path, optimizer)

    with pytest.raises(ValueError, match="Found multiple"):
        strategy.load_checkpoint(tmp_path / "foo", {"model1": model, "model2": model})

    # a state whose only entry is not a model
    with pytest.raises(ValueError, match="Could not find"):
        strategy.load_checkpoint(tmp_path / "foo", {"foo": 1})


class Submodule(torch.nn.Module):
    """Nested test module holding a single bias-free ``Linear(4, 2 * h)``."""

    def __init__(self, h: int):
        super().__init__()
        self.l = torch.nn.Linear(4, h * 2, bias=False)

    def forward(self, x):
        # defined just because preprocessing fails otherwise
        ...
class MyModel(torch.nn.Module): def __init__(self, h: int): super().__init__() self.register_buffer("buf", torch.tensor(0)) self.l = torch.nn.Linear(2, h) self.inner = Submodule(h) def forward(self): # defined just because preprocessing fails otherwise ... def reset_parameters(self): self.buf = torch.empty_like(self.buf) @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_materialize_meta_tensors(): strategy = ThunderFSDPStrategy() fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() with fabric.init_module(empty_init=True): model = MyModel(2) model = fabric.setup(model) # all parameters were moved assert len(list(model.parameters())) == 3 assert all(p.device.type == "cuda" for p in model.parameters()) # buffers were moved too assert model.buf.device.type == "cuda" class StatefulThing: def state_dict(self): return {"thing": 1} def load_state_dict(self, state_dict): assert state_dict == self.state_dict() class TensorLike: def __init__(self, device: Optional[Union[str, torch.device]] = None, shape: Optional[Tuple[int, ...]] = None): self.device = torch.device(device) if device is not None else None self.shape = torch.Size(shape) if shape is not None else None def __eq__(self, other): return ( isinstance(other, torch.Tensor) and (self.device is None or other.device == self.device) and (self.shape is None or other.shape == self.shape) ) @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_save_load_full_checkpoint(tmp_path): strategy = ThunderFSDPStrategy(state_dict_type="full", broadcast_from=0) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() model = MyModel(4) expected = model.state_dict() # save a sharded model model = fabric.setup(model) state = {"model": model, "stateful": StatefulThing(), "primitive": 
123} checkpoint_path = tmp_path / "foo" fabric.save(checkpoint_path, state) # assert the file contents if fabric.global_rank == 0: checkpoint = torch.load(checkpoint_path) # cpu_offload is enabled by default assert checkpoint == { "model": { "buf": TensorLike("cpu", tuple()), "inner.l.weight": TensorLike("cpu", (8, 4)), "l.bias": TensorLike("cpu", (4,)), "l.weight": TensorLike("cpu", (4, 2)), }, "stateful": {"thing": 1}, "primitive": 123, } torch.testing.assert_close(checkpoint["model"], expected) # load its weights into a different sharded model model = MyModel(4) model = fabric.setup(model) state = {"model": model, "stateful": StatefulThing(), "primitive": 321} fabric.load(checkpoint_path, state) from thunder.distributed import _unshard_params # unshard this model's parameters to compare with the original state dict before sharding _unshard_params(model, model.process_group_for_ddp, True) # we loaded rank 0's weights, so this would fail in the other ranks if fabric.global_rank == 0: actual = model.state_dict() # `_unshard_params` doesn't offload buffers at the moment assert actual["buf"].device.type == "cuda" actual["buf"] = actual["buf"].to(device="cpu") torch.testing.assert_close(actual, expected) assert state["primitive"] == 123 @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_load_full_checkpoint_only_model(tmp_path): strategy = ThunderFSDPStrategy() fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() checkpoint_path = tmp_path / "foo" checkpoint_path = fabric.broadcast(checkpoint_path) if fabric.global_rank == 0: model = MyModel(4) expected = model.state_dict() torch.save(expected, checkpoint_path) fabric.barrier() expected = torch.load(checkpoint_path) # before sharding model = MyModel(4) fabric.load_raw(checkpoint_path, model) torch.testing.assert_close(model.state_dict(), expected) # after sharding model = MyModel(4) model 
= fabric.setup(model) fabric.load_raw(checkpoint_path, model) from thunder.distributed import _unshard_params # unshard this model's parameters to compare with the original state dict before sharding _unshard_params(model, model.process_group_for_ddp, True) actual = model.state_dict() # `_unshard_params` doesn't offload buffers at the moment assert actual["buf"].device.type == "cuda" actual["buf"] = actual["buf"].to(device="cpu") torch.testing.assert_close(actual, expected) def distributed_ckpt_to_regular(path): """From ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save``.""" from torch.distributed.checkpoint import FileSystemReader from torch.distributed.checkpoint.state_dict_loader import _load_state_dict if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner else: from torch.distributed.checkpoint._traverse import set_element from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner from torch.distributed.checkpoint.metadata import TensorStorageMetadata class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def set_up_planner(self, state_dict, metadata, is_coordinator): assert not state_dict # rebuild the state dict from the metadata for k, v in metadata.state_dict_metadata.items(): if isinstance(v, TensorStorageMetadata): v = torch.empty(v.size, dtype=v.properties.dtype) if k in metadata.planner_data: set_element(state_dict, metadata.planner_data[k], v) else: state_dict[k] = v super().set_up_planner(state_dict, metadata, is_coordinator) state_dict = {} storage_reader = FileSystemReader(path) _load_state_dict(state_dict, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True) return state_dict @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_save_load_sharded_checkpoint(tmp_path): 
strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() model = MyModel(4) expected = model.state_dict() # save a sharded model model = fabric.setup(model) state = {"model": model, "stateful": StatefulThing(), "primitive": 123} fabric.save(tmp_path, state) # assert the file contents if fabric.global_rank == 0: assert set(os.listdir(tmp_path)) == {"meta.pt", "__1_0.distcp", "__0_0.distcp", ".metadata"} metadata = torch.load(tmp_path / "meta.pt") assert metadata == {"stateful": {"thing": 1}, "primitive": 123} checkpoint = distributed_ckpt_to_regular(tmp_path) # cpu_offload is enabled by default assert checkpoint == { "model": { "buf": TensorLike("cpu", tuple()), "inner.l.weight": TensorLike("cpu", (8, 4)), "l.bias": TensorLike("cpu", (4,)), "l.weight": TensorLike("cpu", (4, 2)), } } torch.testing.assert_close(checkpoint["model"], expected) # load its weights into a different sharded model model = MyModel(4) model = fabric.setup(model) state = {"model": model, "stateful": StatefulThing(), "primitive": 321} fabric.load(tmp_path, state) from thunder.distributed import _unshard_params # unshard this model's parameters to compare with the original state dict before sharding _unshard_params(model, model.process_group_for_ddp, True) # we loaded rank 0's weights, so this would fail in the other ranks if fabric.global_rank == 0: actual = model.state_dict() # `_unshard_params` doesn't offload buffers at the moment assert actual["buf"].device.type == "cuda" actual["buf"] = actual["buf"].to(device="cpu") torch.testing.assert_close(actual, expected) assert state["primitive"] == 123 @_RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.parametrize("jit", (False, True)) @pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder") def test_jit_fsdp_before_setup(jit): import thunder fabric = Fabric(devices=2, accelerator="cuda", 
strategy=ThunderFSDPStrategy(jit=jit)) fabric.launch() x = torch.randn(1, 1, device=fabric.device) model = torch.nn.Linear(1, 2, bias=False, device=fabric.device) tmodel = thunder.jit(model) fmodel = fabric.setup(tmodel) fmodel(x) assert "all_gather" in thunder.last_traces(tmodel)[-1].python() @_RunIf(min_cuda_gpus=1, thunder=True) def test_strategy_fsdp_setup_already_traced(): import thunder device = torch.device("cuda") x = torch.randn(1, 1, device=device) model = torch.nn.Linear(1, 2, bias=False, device=device) strategy = ThunderFSDPStrategy() tmodel = thunder.jit(model) tmodel(x) with pytest.raises(RuntimeError, match="already called"): strategy.setup_module(tmodel) ================================================ FILE: tests/ext_thunder/test_thunder_networks.py ================================================ """Run thunder tests as part of LitGPT CI""" from litgpt.constants import _THUNDER_AVAILABLE if _THUNDER_AVAILABLE: from thunder.tests.test_networks import * # noqa: F403 else: print("Skipping test_thunder_networks.py (thunder not available)") ================================================ FILE: tests/ext_thunder/test_thunder_pretrain.py ================================================ import os from contextlib import redirect_stdout from io import StringIO from unittest.mock import Mock import torch from torch.utils.data import DataLoader from litgpt import Config from litgpt.args import EvalArgs, TrainArgs from litgpt.constants import _THUNDER_AVAILABLE from litgpt.utils import _RunIf if _THUNDER_AVAILABLE: import extensions.thunder.pretrain as thunder_pretrain @_RunIf(min_cuda_gpus=1, thunder=True) def test_pretrain_thunder(tmp_path, monkeypatch): model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]]) dataloader = DataLoader(dataset) monkeypatch.setattr(thunder_pretrain, "get_dataloaders", Mock(return_value=(dataloader, dataloader))) 
monkeypatch.setattr(thunder_pretrain, "save_hyperparameters", Mock()) out_dir = tmp_path / "out" stdout = StringIO() with redirect_stdout(stdout): thunder_pretrain.setup( devices=1, model_config=model_config, out_dir=out_dir, train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0), eval=EvalArgs(interval=1, max_iters=1), optimizer="AdamW", ) out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "model_config.yaml"} assert (out_dir / "logs" / "tensorboard" / "version_0").is_dir() logs = stdout.getvalue() assert logs.count("(step)") == 4 assert logs.count("val loss") == 4 assert "Total parameters: 1,888" in logs ================================================ FILE: tests/ext_thunder/test_unsloth_executor.py ================================================ import pytest import torch from litgpt import GPT, Config from litgpt.model import apply_rope, build_rope_cache from litgpt.utils import _RunIf, chunked_cross_entropy @_RunIf(min_cuda_gpus=1, thunder=True) @pytest.mark.parametrize("reduction", ["none", "mean"]) def test_unsloth_cross_entropy(reduction): import thunder from extensions.thunder.unsloth.executor import unsloth_ex logits = torch.randn(64, 128, device="cuda", requires_grad=True) labels = torch.randint(128, (64,), device="cuda") def foo(logits, labels): # this is the variant supported by unsloth. 
# if different arguments are used, the implementation would no be lowered to unsloth and instead would get # decomposed return torch.nn.functional.cross_entropy(logits, labels, reduction=reduction, ignore_index=-100) cfoo = thunder.jit(foo, executors=[unsloth_ex]) actual = cfoo(logits, labels) trace_str = str(thunder.last_traces(cfoo)[-1]) assert "unsloth_cross_entropy" in trace_str and "backward" not in trace_str trace_str = str(thunder.last_backward_traces(cfoo)[-1]) assert "unsloth_cross_entropy_backward" in trace_str expected = foo(logits, labels) torch.testing.assert_close(actual, expected) (actual_grad,) = torch.autograd.grad(actual.sum(), logits) trace_str = str(thunder.last_backward_traces(cfoo)[-1]) assert "unsloth_cross_entropy_backward" in trace_str out = foo(logits, labels) assert logits.grad is None (expected_grad,) = torch.autograd.grad(out.sum(), logits) torch.testing.assert_close(actual_grad, expected_grad) @pytest.mark.skip(reason="out of date") @_RunIf(min_cuda_gpus=1, thunder=True) def test_unsloth_rope(): import thunder from extensions.thunder.unsloth.executor import unsloth_ex B, nh, T, hs = 2, 32, 64, 16 cos, sin = build_rope_cache(T, hs, device="cuda") cos = cos.unsqueeze(0) sin = sin.unsqueeze(0) q = torch.rand((B, nh, T, hs), device="cuda", requires_grad=True) def foo(x, cos, sin): return apply_rope(x, cos, sin) cfoo = thunder.jit(foo, executors=[unsloth_ex]) actual = cfoo(q, cos, sin) trace_str = str(thunder.last_traces(cfoo)[-1]) assert "unsloth_apply_rope" in trace_str and "backward" not in trace_str trace_str = str(thunder.last_backward_traces(cfoo)[-1]) assert "unsloth_apply_rope_backward" in trace_str expected = foo(q, cos, sin) torch.testing.assert_close(actual, expected) (actual_grad,) = torch.autograd.grad(actual.sum(), q) (expected_grad,) = torch.autograd.grad(expected.sum(), q) torch.testing.assert_close(actual_grad, expected_grad) @_RunIf(min_cuda_gpus=1, thunder=True) def test_unsloth_swiglu(): import thunder from 
extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex from litgpt import Config from litgpt.model import LLaMAMLP config = Config.from_name("Llama-2-7b-hf") with torch.device("cuda"): x = torch.randn(2, 16, config.n_embd, requires_grad=True) mlp = LLaMAMLP(config) # monkeypatching was successful assert isinstance(mlp, ThunderLLaMAMLP) cmlp = thunder.jit(mlp, executors=[unsloth_ex]) actual = cmlp(x) trace_str = str(thunder.last_traces(cmlp)[-1]) assert "unsloth_swiglu" in trace_str and "backward" not in trace_str trace_str = str(thunder.last_backward_traces(cmlp)[-1]) assert "unsloth_swiglu_backward" in trace_str expected = mlp(x) torch.testing.assert_close(actual, expected) (actual_grad,) = torch.autograd.grad(actual.sum(), x) (expected_grad,) = torch.autograd.grad(expected.sum(), x) torch.testing.assert_close(actual_grad, expected_grad) @_RunIf(min_cuda_gpus=1, thunder=True) def test_unsloth_gpt(): import thunder from extensions.thunder.unsloth.executor import unsloth_ex def forward_and_loss(model, input_ids, targets): logits = model(input_ids) return chunked_cross_entropy(logits, targets, chunk_size=0) cfn = thunder.jit(forward_and_loss, executors=[unsloth_ex]) device = torch.device("cuda") config = Config( vocab_size=320, padding_multiple=64, n_layer=2, n_head=4, n_embd=64, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=1376, ) with device: model = GPT(config) input_ids = torch.randint(1, 10, (2, 3)) targets = torch.randint(0, 10, (2, 3)) loss = cfn(model, input_ids, targets) assert isinstance(loss, torch.Tensor) fwd = thunder.last_traces(cfn) bwd = thunder.last_backward_traces(cfn) fwd_str, bwd_str = fwd[-1].python(), bwd[-1].python() assert "unsloth_cross_entropy" in fwd_str assert "unsloth_cross_entropy_backward" in bwd_str assert "unsloth_apply_rope" in fwd_str assert "unsloth_apply_rope_backward" in bwd_str assert "unsloth_swiglu" in fwd_str assert 
"unsloth_swiglu_backward" in bwd_str ================================================ FILE: tests/generate/__init__.py ================================================ ================================================ FILE: tests/generate/test_adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os import re import subprocess import sys from contextlib import redirect_stderr, redirect_stdout from io import StringIO from unittest.mock import ANY, Mock, call import pytest import torch import yaml skip_in_ci_on_macos = pytest.mark.skipif( sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true", reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.", ) @skip_in_ci_on_macos @pytest.mark.parametrize("version", ("v1", "v2")) def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): if version == "v1": import litgpt.generate.adapter as generate else: import litgpt.generate.adapter_v2 as generate config_path = fake_checkpoint_dir / "model_config.yaml" config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1} config_path.write_text(yaml.dump(config)) monkeypatch.setattr(generate, "lazy_load", Mock()) monkeypatch.setattr(generate.GPT, "load_state_dict", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]]) tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz" monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() generate_mock.return_value = torch.tensor([[3, 2, 1]]) monkeypatch.setattr(generate, "generate", generate_mock) num_samples = 1 out, err = StringIO(), StringIO() with redirect_stdout(out), redirect_stderr(err): generate.main(temperature=2.0, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir) assert 
len(tokenizer_mock.return_value.decode.mock_calls) == num_samples assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) assert ( generate_mock.mock_calls == [call(ANY, tensor_like, 101, temperature=2.0, top_k=2, top_p=0.9, eos_id=ANY)] * num_samples ) expected_output = "foo bar baz\n" * num_samples # Allow for the config to be printed before the expected repeated strings. pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) err_value = err.getvalue() expected_parts = [ "'padded_vocab_size': 512", "'n_layer': 2", "'n_head': 4", "'head_size': 2", "'n_embd': 8", ] assert all(part in err_value for part in expected_parts) @pytest.mark.parametrize("version", ("", "_v2")) def test_cli(version): args = ["litgpt", f"generate_adapter{version}", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "For models finetuned with" in output ================================================ FILE: tests/generate/test_main.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os import re import subprocess import sys from contextlib import redirect_stderr, redirect_stdout from io import StringIO from unittest import mock from unittest.mock import ANY, Mock, call import pytest import torch import yaml import litgpt.generate.base as generate from litgpt import GPT, Config from litgpt.generate.base import sample skip_in_ci_on_macos = pytest.mark.skipif( sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true", reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.", ) @pytest.mark.parametrize( "max_seq_length", (pytest.param(10, marks=pytest.mark.xfail(raises=NotImplementedError, strict=True)), 20 + 5) ) def test_generate(max_seq_length): import lightning as L L.seed_everything(1234) T = 5 input_idx = torch.arange(0, T) config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) model = GPT(config) model.max_seq_length = max_seq_length model.set_kv_cache(batch_size=1) max_new_tokens = 20 multinomial_results = [] def multinomial(*args, **kwargs): out = torch.multinomial(*args, **kwargs, num_samples=1) multinomial_results.append(out) return out with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial): out = generate.generate(model, input_idx, T + max_new_tokens, top_k=1) assert out.size(0) == T + max_new_tokens, (out.size(0), T + max_new_tokens) multinomial_results = torch.hstack(multinomial_results) expected = torch.cat((input_idx, multinomial_results)) assert out.shape == expected.shape, (out.shape, expected.shape) torch.testing.assert_close(out, expected) @skip_in_ci_on_macos def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): config_path = fake_checkpoint_dir / "model_config.yaml" config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1} config_path.write_text(yaml.dump(config)) module_mock = Mock() module_mock.config.block_size = 128 load_mock = Mock() 
load_mock.return_value = load_mock monkeypatch.setattr(generate, "load_checkpoint", load_mock) tokenizer_mock = Mock() tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3]) tokenizer_mock.return_value.decode.return_value = "foo bar baz" monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() generate_mock.return_value = torch.tensor([3, 2, 1]) monkeypatch.setattr(generate, "generate", generate_mock) num_samples = 2 out, err = StringIO(), StringIO() with redirect_stdout(out), redirect_stderr(err): generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir) assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) assert ( generate_mock.mock_calls == [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)] * num_samples ) expected_output = "foo bar baz\n" * num_samples # Allow for the config to be printed before the expected repeated strings. 
pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) err_value = err.getvalue() expected_parts = [ "'padded_vocab_size': 512", "'n_layer': 2", "'n_head': 4", ] assert all(part in err_value for part in expected_parts) def test_cli(): args = ["litgpt", "generate", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "Default generation option" in output @pytest.mark.parametrize("temperature", (0.0, 1.0, 0.5)) def test_sample(temperature): # shape: 2x3x5 logits = torch.tensor( [ [[24, 4, 98, 77, 47], [65, 70, 32, 67, 24], [92, 32, 88, 36, 62]], [[85, 79, 57, 68, 50], [89, 46, 72, 45, 32], [68, 96, 68, 24, 36]], ], dtype=torch.float32, ) token = sample(logits, temperature=temperature, top_p=0.8) assert token.shape == (1,) # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported assert token.tolist() == [0] def test_generate_different_results_with_different_top_p(): config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) model = GPT(config) model.max_seq_length = 50 model.set_kv_cache(batch_size=1) torch.manual_seed(123) input_idx = torch.randint(10, size=(1,)) torch.manual_seed(123) output1 = generate.generate(model, input_idx, 20, top_p=1.0) torch.manual_seed(123) output2 = generate.generate(model, input_idx, 20, top_p=0.1) assert not torch.equal(output1, output2) ================================================ FILE: tests/generate/test_sequentially.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import itertools import subprocess import sys from dataclasses import asdict from pathlib import Path from re import escape import pytest import torch import yaml from lightning import Fabric from litgpt import Config from litgpt.generate.sequentially import ( chunk_sizes, layer_to_device, replace_device, sequential, ) from litgpt.model import GPT, Block from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf from .utils import find_forward_hooks @pytest.mark.parametrize( ("n_layer", "devices", "expected"), [ (6, 1, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}), (6, 2, {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}), (6, 3, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}), (6, 4, {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 3}), (6, 5, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}), (6, 6, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}), ], ) def test_layer_to_device(n_layer, devices, expected): with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=n_layer) c_sizes = chunk_sizes(n_layer, devices) actual = layer_to_device(model, Block, chunk_sizes=c_sizes) expected = {f"transformer.h.{i}": v for i, v in expected.items()} assert actual == expected def path_to_device(model): return {k: str(v.device) for k, v in itertools.chain(model.named_parameters(), model.named_buffers())} def test_replace_device(): class Submodule(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("foo", torch.tensor(1, device="cpu")) self.register_buffer("bar", torch.tensor(1, device="cpu")) class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.modules = torch.nn.ModuleDict( { "module1": torch.nn.Linear(1, 1, bias=True, device="meta"), "module2": torch.nn.Linear(1, 1, bias=False, device="cpu"), } ) self.submodule = Submodule() model = MyModel() assert path_to_device(model) == { "modules.module1.bias": "meta", "modules.module1.weight": "meta", "modules.module2.weight": "cpu", "submodule.bar": "cpu", "submodule.foo": "cpu", } model = replace_device(model, 
torch.device("cpu"), torch.device("meta")) assert path_to_device(model) == { "modules.module1.bias": "meta", "modules.module1.weight": "meta", "modules.module2.weight": "meta", "submodule.bar": "meta", "submodule.foo": "meta", } model = MyModel() model.submodule.bar = model.submodule.bar.to("meta") with pytest.raises( ValueError, match=escape("multiple devices: {'submodule.foo': device(type='cpu'), 'submodule.bar': device(type='meta')}"), ): replace_device(model, torch.device("cpu"), torch.device("meta")) def _test_model_1device(accelerator): fabric = Fabric(accelerator=accelerator, devices=1) with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=2) model = sequential(model, fabric.device, 15, 1) device_str = str(fabric.device) assert path_to_device(model) == { "cos": device_str, "sin": device_str, "lm_head.weight": device_str, "transformer.h.0.attn.qkv.bias": device_str, "transformer.h.0.attn.qkv.weight": device_str, "transformer.h.0.attn.proj.bias": device_str, "transformer.h.0.attn.proj.weight": device_str, "transformer.h.0.mlp.fc.bias": device_str, "transformer.h.0.mlp.fc.weight": device_str, "transformer.h.0.mlp.proj.bias": device_str, "transformer.h.0.mlp.proj.weight": device_str, "transformer.h.0.norm_1.bias": device_str, "transformer.h.0.norm_1.weight": device_str, "transformer.h.0.norm_2.bias": device_str, "transformer.h.0.norm_2.weight": device_str, "transformer.h.0.attn.kv_cache.k": device_str, "transformer.h.0.attn.kv_cache.v": device_str, "transformer.h.1.attn.qkv.bias": device_str, "transformer.h.1.attn.qkv.weight": device_str, "transformer.h.1.attn.proj.bias": device_str, "transformer.h.1.attn.proj.weight": device_str, "transformer.h.1.mlp.fc.bias": device_str, "transformer.h.1.mlp.fc.weight": device_str, "transformer.h.1.mlp.proj.bias": device_str, "transformer.h.1.mlp.proj.weight": device_str, "transformer.h.1.norm_1.bias": device_str, "transformer.h.1.norm_1.weight": device_str, "transformer.h.1.norm_2.bias": device_str, 
"transformer.h.1.norm_2.weight": device_str, "transformer.h.1.attn.kv_cache.k": device_str, "transformer.h.1.attn.kv_cache.v": device_str, "transformer.ln_f.bias": device_str, "transformer.ln_f.weight": device_str, "transformer.wte.weight": device_str, } assert model.max_seq_length == 15 @_RunIf(min_cuda_gpus=1) def test_model_1device_cuda(): _test_model_1device("cuda") def test_model_1device_cpu(): _test_model_1device("cpu") @_RunIf(min_cuda_gpus=2) def test_model_forward_hooks(): fabric = Fabric(accelerator="cuda", devices=1) with torch.device("meta"): model = GPT.from_name("pythia-14m") # 6 layers model = sequential(model, fabric.device, max_seq_length=15, devices=2) hooks = find_forward_hooks(model) actual = path_to_device(model) assert actual == { "lm_head.weight": "cuda:0", "transformer.wte.weight": "cuda:0", "transformer.h.0.norm_1.weight": "cuda:0", "transformer.h.0.norm_1.bias": "cuda:0", "transformer.h.0.attn.qkv.weight": "cuda:0", "transformer.h.0.attn.qkv.bias": "cuda:0", "transformer.h.0.attn.proj.weight": "cuda:0", "transformer.h.0.attn.proj.bias": "cuda:0", "transformer.h.0.norm_2.weight": "cuda:0", "transformer.h.0.norm_2.bias": "cuda:0", "transformer.h.0.mlp.fc.weight": "cuda:0", "transformer.h.0.mlp.fc.bias": "cuda:0", "transformer.h.0.mlp.proj.weight": "cuda:0", "transformer.h.0.mlp.proj.bias": "cuda:0", "transformer.h.1.norm_1.weight": "cuda:0", "transformer.h.1.norm_1.bias": "cuda:0", "transformer.h.1.attn.qkv.weight": "cuda:0", "transformer.h.1.attn.qkv.bias": "cuda:0", "transformer.h.1.attn.proj.weight": "cuda:0", "transformer.h.1.attn.proj.bias": "cuda:0", "transformer.h.1.norm_2.weight": "cuda:0", "transformer.h.1.norm_2.bias": "cuda:0", "transformer.h.1.mlp.fc.weight": "cuda:0", "transformer.h.1.mlp.fc.bias": "cuda:0", "transformer.h.1.mlp.proj.weight": "cuda:0", "transformer.h.1.mlp.proj.bias": "cuda:0", "transformer.h.2.norm_1.weight": "cuda:0", "transformer.h.2.norm_1.bias": "cuda:0", "transformer.h.2.attn.qkv.weight": "cuda:0", 
"transformer.h.2.attn.qkv.bias": "cuda:0", "transformer.h.2.attn.proj.weight": "cuda:0", "transformer.h.2.attn.proj.bias": "cuda:0", "transformer.h.2.norm_2.weight": "cuda:0", "transformer.h.2.norm_2.bias": "cuda:0", "transformer.h.2.mlp.fc.weight": "cuda:0", "transformer.h.2.mlp.fc.bias": "cuda:0", "transformer.h.2.mlp.proj.weight": "cuda:0", "transformer.h.2.mlp.proj.bias": "cuda:0", "transformer.h.3.norm_1.weight": "cuda:1", "transformer.h.3.norm_1.bias": "cuda:1", "transformer.h.3.attn.qkv.weight": "cuda:1", "transformer.h.3.attn.qkv.bias": "cuda:1", "transformer.h.3.attn.proj.weight": "cuda:1", "transformer.h.3.attn.proj.bias": "cuda:1", "transformer.h.3.norm_2.weight": "cuda:1", "transformer.h.3.norm_2.bias": "cuda:1", "transformer.h.3.mlp.fc.weight": "cuda:1", "transformer.h.3.mlp.fc.bias": "cuda:1", "transformer.h.3.mlp.proj.weight": "cuda:1", "transformer.h.3.mlp.proj.bias": "cuda:1", "transformer.h.4.norm_1.weight": "cuda:1", "transformer.h.4.norm_1.bias": "cuda:1", "transformer.h.4.attn.qkv.weight": "cuda:1", "transformer.h.4.attn.qkv.bias": "cuda:1", "transformer.h.4.attn.proj.weight": "cuda:1", "transformer.h.4.attn.proj.bias": "cuda:1", "transformer.h.4.norm_2.weight": "cuda:1", "transformer.h.4.norm_2.bias": "cuda:1", "transformer.h.4.mlp.fc.weight": "cuda:1", "transformer.h.4.mlp.fc.bias": "cuda:1", "transformer.h.4.mlp.proj.weight": "cuda:1", "transformer.h.4.mlp.proj.bias": "cuda:1", "transformer.h.5.norm_1.weight": "cuda:1", "transformer.h.5.norm_1.bias": "cuda:1", "transformer.h.5.attn.qkv.weight": "cuda:1", "transformer.h.5.attn.qkv.bias": "cuda:1", "transformer.h.5.attn.proj.weight": "cuda:1", "transformer.h.5.attn.proj.bias": "cuda:1", "transformer.h.5.norm_2.weight": "cuda:1", "transformer.h.5.norm_2.bias": "cuda:1", "transformer.h.5.mlp.fc.weight": "cuda:1", "transformer.h.5.mlp.fc.bias": "cuda:1", "transformer.h.5.mlp.proj.weight": "cuda:1", "transformer.h.5.mlp.proj.bias": "cuda:1", "transformer.ln_f.weight": "cuda:0", 
"transformer.ln_f.bias": "cuda:0", "cos": "cuda:0", "sin": "cuda:0", "transformer.h.0.attn.kv_cache.k": "cuda:0", "transformer.h.0.attn.kv_cache.v": "cuda:0", "transformer.h.1.attn.kv_cache.k": "cuda:0", "transformer.h.1.attn.kv_cache.v": "cuda:0", "transformer.h.2.attn.kv_cache.k": "cuda:0", "transformer.h.2.attn.kv_cache.v": "cuda:0", "transformer.h.3.attn.kv_cache.k": "cuda:1", "transformer.h.3.attn.kv_cache.v": "cuda:1", "transformer.h.4.attn.kv_cache.k": "cuda:1", "transformer.h.4.attn.kv_cache.v": "cuda:1", "transformer.h.5.attn.kv_cache.k": "cuda:1", "transformer.h.5.attn.kv_cache.v": "cuda:1", } assert hooks == { "transformer.h.3": [("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {})], "transformer.h.4": [("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {})], "transformer.h.5": [ ("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {}), ("forward_hook", "move_block_output", (torch.device(type="cuda", index=0),), {}), ], } root = Path(__file__).parent.parent.resolve() @_RunIf(min_cuda_gpus=2) @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_base_with_sequentially(tmp_path): # download the tokenizer download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) checkpoint_dir = tmp_path / "EleutherAI/pythia-14m" # save the config config = Config.from_name("pythia-14m") (checkpoint_dir / "model_config.yaml").write_text(yaml.dump(asdict(config))) # create a state dict to load from torch.save(GPT(config).state_dict(), checkpoint_dir / "lit_model.pth") args = [ str(checkpoint_dir), "--num_samples=1", "--max_new_tokens=10", "--precision=16-true", "--temperature=0.0", ] env = {"CUDA_VISIBLE_DEVICES": "0,1"} sequential_stdout = subprocess.check_output( [sys.executable, "-m", "litgpt", "generate_sequentially", *args], env=env, cwd=root, ).decode() assert "What food do llamas eat?" 
in sequential_stdout def test_cli(): args = ["litgpt", "generate_sequentially", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "Generation script that partitions layers across" in output ================================================ FILE: tests/generate/test_tp.py ================================================ import subprocess import sys from dataclasses import asdict, replace from pathlib import Path from unittest.mock import Mock import pytest import torch import yaml from litgpt import GPT, Config from litgpt.generate.tp import tensor_parallel, tensor_parallel_linear from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf from .utils import find_forward_hooks def test_tensor_parallel_linear(): fabric = Mock() fabric.world_size = 4 fabric.global_rank = 2 def get_linear(bias=True): linear = torch.nn.Linear(8, 8, bias=bias) linear.weight.data = torch.arange(64, dtype=torch.float32).reshape(8, 8) if bias: linear.bias.data = torch.arange(8, dtype=torch.float32) return linear linear = get_linear() tensor_parallel_linear(fabric, linear, "colwise") expected = torch.arange(32, 48, dtype=torch.float32).reshape(2, 8) torch.testing.assert_close(linear.weight, expected) expected = torch.arange(4, 6, dtype=torch.float32) torch.testing.assert_close(linear.bias, expected) linear = get_linear(bias=False) tensor_parallel_linear(fabric, linear, "rowwise") expected = torch.arange(4, 62, 8, dtype=torch.float32).reshape(8, 1) expected = torch.cat([expected, expected + 1], dim=1) torch.testing.assert_close(linear.weight, expected) assert linear.bias is None @pytest.mark.parametrize( ("name", "expected"), [ ( "Llama-2-70b-hf", { "transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.0.mlp": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.mlp": [("forward_hook", "all_reduce_output", (8,), 
{})], "transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.mlp": [("forward_hook", "all_reduce_output", (8,), {})], }, ), ( "falcon-180B", { "transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.0.mlp": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.mlp": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.mlp": [("forward_hook", "all_reduce_output", (8,), {})], }, ), ( "Mixtral-8x7B-v0.1", { "transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.0.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.0.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.1.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})], "transformer.h.2.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})], }, ), ], ) def test_tensor_parallel_llama(name, expected): fabric = Mock() fabric.world_size = 8 fabric.global_rank = 1 with torch.device("meta"): model = GPT.from_name(name, n_layer=3, n_expert=2) config = replace(model.config) # make a copy model = tensor_parallel(fabric, model) hooks = find_forward_hooks(model) assert hooks == expected assert model.config.n_embd * 8 == config.n_embd assert model.config.n_head * 8 == config.n_head assert model.config.n_query_groups * 8 == config.n_query_groups root = Path(__file__).parent.parent.resolve() @_RunIf(min_cuda_gpus=2) def test_tp(tmp_path): # download the 
tokenizer download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) checkpoint_dir = tmp_path / "EleutherAI/pythia-14m" # save the config config = Config.from_name("pythia-14m") (checkpoint_dir / "model_config.yaml").write_text(yaml.dump(asdict(config))) # create a state dict to load from torch.save(GPT(config).state_dict(), checkpoint_dir / "lit_model.pth") args = [ str(checkpoint_dir), "--num_samples=1", "--max_new_tokens=10", "--precision=16-true", "--temperature=0.0", ] env = {"CUDA_VISIBLE_DEVICES": "0,1"} tp_stdout = subprocess.check_output( [sys.executable, "-m", "litgpt", "generate_tp", *args], env=env, cwd=root ).decode() # there is some unaccounted randomness so cannot compare the output with that of `generate/base.py` assert "What food do llamas eat?" in tp_stdout def test_cli(): args = ["litgpt", "generate_tp", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "Generation script that uses tensor parallelism" in output ================================================ FILE: tests/generate/utils.py ================================================ from collections import defaultdict def find_forward_hooks(module): mapping = defaultdict(list) for name, submodule in module.named_modules(): for hook in submodule._forward_pre_hooks.values(): hook_data = ("forward_pre_hook", hook.func.__name__, hook.args, hook.keywords) mapping[name].append(hook_data) for hook in submodule._forward_hooks.values(): hook_data = ("forward_hook", hook.func.__name__, hook.args, hook.keywords) mapping[name].append(hook_data) return dict(mapping) ================================================ FILE: tests/test_adapter.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout from copy import deepcopy from dataclasses import asdict from io import StringIO from unittest import mock from unittest.mock import Mock import pytest import torch import yaml from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig import litgpt.adapter as gpt_adapter import litgpt.finetune.adapter as module import litgpt.model as gpt from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from litgpt.utils import _RunIf def test_config_identical(): name = "pythia-14m" base_config = asdict(gpt.Config.from_name(name)) adapter_config = asdict(gpt_adapter.Config.from_name(name)) del adapter_config["adapter_prompt_length"] del adapter_config["adapter_start_layer"] assert adapter_config == base_config with Fabric(accelerator="cpu").init_module(empty_init=True): base_model = gpt.GPT.from_name(name) adapter_model = gpt_adapter.GPT.from_name(name) assert adapter_model.lm_head.weight.shape == base_model.lm_head.weight.shape def test_adapter_filter(tmp_path): fabric = Fabric(devices=1) model = GPT.from_name("pythia-14m", n_layer=4) save_path = tmp_path / "model.pth" fabric.save(save_path, {"model": model}, filter={"model": adapter_filter}) saved = torch.load(save_path)["model"] expected = { "transformer.h.2.attn.adapter_wte.weight", 
"transformer.h.2.attn.gating_factor", "transformer.h.3.attn.adapter_wte.weight", "transformer.h.3.attn.gating_factor", } assert set(saved) == expected @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) monkeypatch.setattr(module, "load_checkpoint", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) out_dir = tmp_path / "out" stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), out_dir=out_dir, precision="32-true", train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), ) out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == { "lit_model.pth.adapter", "model_config.yaml", "tokenizer_config.json", "tokenizer.json", "hyperparameters.yaml", "prompt_style.yaml", } assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() logs = stdout.getvalue() assert logs.count("(step)") == 6 assert logs.count("val loss") == 4 # 3 validations + 1 final validation assert logs.count("Final evaluation") == 1 assert "of trainable parameters: 168" in logs def 
test_adapter_gpt_init_weights(): config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0) model = GPT(config) param = model.transformer.h[0].attn.gating_factor assert (param == 0).all() torch.nn.init.constant_(param, 1.23) assert (param != 0).any() model.apply(model._init_weights) assert (param == 0).all() @_RunIf(dynamo=True) @torch.inference_mode() def test_adapter_compile(): model = GPT.from_name("pythia-14m", n_layer=3) x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64) explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 model = GPT(model.config) model.set_kv_cache(2) input_pos = torch.arange(model.config.block_size) explanation = torch._dynamo.explain(model)(x, input_pos) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 @_RunIf(min_cuda_gpus=1) def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): if not _BITSANDBYTES_AVAILABLE: pytest.skip("BNB not available") from bitsandbytes.optim import PagedAdamW model_config = dict( block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True ) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) monkeypatch.setattr(module, "load_checkpoint", Mock()) train_mock = Mock() train_mock.return_value = { "raw_tokens": 1000, "raw_tokens_plus_prompt_template": 1100, "raw_tokens_plus_prompt_template_and_padding": 1200, } monkeypatch.setattr(module, "fit", train_mock) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", 
["adapter.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), precision="16-true", quantize="bnb.nf4-dq", out_dir=tmp_path, ) _, kwargs = train_mock.call_args fabric = kwargs["fabric"] model = kwargs["model"] optimizer = kwargs["optimizer"] assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) assert isinstance(optimizer, _FabricOptimizer) assert isinstance(optimizer._optimizer, PagedAdamW) dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} for name, layer in model.named_parameters(): name = name[len("_forward_module.") :] dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.gating_factor", "transformer.h.0.attn.qkv.bias", "transformer.h.0.attn.proj.bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.h.0.norm_2.weight", "transformer.h.0.norm_2.bias", "transformer.h.0.mlp.fc.bias", "transformer.h.0.mlp.proj.bias", "transformer.h.1.norm_1.weight", "transformer.h.1.norm_1.bias", "transformer.h.1.attn.gating_factor", "transformer.h.1.attn.qkv.bias", "transformer.h.1.attn.proj.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.1.mlp.fc.bias", "transformer.h.1.mlp.proj.bias", "transformer.ln_f.weight", "transformer.ln_f.bias", }, "torch.uint8": { "lm_head.weight", "transformer.h.0.attn.qkv.weight", "transformer.h.0.attn.proj.weight", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.proj.weight", "transformer.h.1.attn.qkv.weight", "transformer.h.1.attn.proj.weight", "transformer.h.1.mlp.fc.weight", "transformer.h.1.mlp.proj.weight", }, } assert {p.name for p in tmp_path.rglob("*.pth.adapter")} == {"lit_model.pth.adapter"} 
state_dict = torch.load(tmp_path / "final" / "lit_model.pth.adapter") assert len(state_dict) == 1 dtype_to_name = {"torch.float16": set()} for name, layer in state_dict["model"].items(): dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { "transformer.h.0.attn.adapter_wte.weight", "transformer.h.0.attn.gating_factor", "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.attn.gating_factor", } } logs = stdout.getvalue() assert "of trainable parameters: 168" in logs assert "of non-trainable parameters: 1,888" in logs @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) def test_against_hf_gemma(model_name): device = torch.device("cpu") dtype = torch.float32 T = 5 ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to 
float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_2(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Gemma2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = 
GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_3(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) 
assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Gemma3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_3({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5) def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, ) attention_1 = CausalSelfAttention(config=config, block_idx=0) # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` # and make them interleaved state_dict = deepcopy(attention_1.state_dict()) state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) ================================================ FILE: tests/test_adapter_v2.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout from copy import deepcopy from io import StringIO from unittest import mock from unittest.mock import Mock import pytest import torch import yaml from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import litgpt.config as config_module import litgpt.finetune.adapter_v2 as module from litgpt.adapter_v2 import GPT as AdapterV2GPT from litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from litgpt.utils import _RunIf def test_config_identical(): name = "pythia-14m" with Fabric(accelerator="cpu").init_module(empty_init=True): base_model = BaseGPT.from_name(name) adapter_model = AdapterV2GPT.from_name(name) assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias") assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale") assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias") assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale") def test_adapter_v2_filter(tmp_path): fabric = Fabric(devices=1) model = AdapterV2GPT.from_name("pythia-14m", n_layer=3) save_path = tmp_path / "model.pth" fabric.save(save_path, {"model": model}, filter={"model": adapter_filter}) saved = 
torch.load(save_path)["model"] expected = { "lm_head.adapter_bias", "lm_head.adapter_scale", "transformer.ln_f.bias", "transformer.ln_f.weight", "transformer.h.2.attn.adapter_wte.weight", "transformer.h.2.attn.gating_factor", } for layer in range(3): for param in ( "attn.qkv.adapter_bias", "attn.qkv.adapter_scale", "attn.proj.adapter_bias", "attn.proj.adapter_scale", "mlp.fc.adapter_bias", "mlp.fc.adapter_scale", "mlp.proj.adapter_bias", "mlp.proj.adapter_scale", "norm_1.bias", "norm_1.weight", "norm_2.bias", "norm_2.weight", ): expected.add(f"transformer.h.{layer}.{param}") assert set(saved) == expected @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) monkeypatch.setattr(module, "load_checkpoint", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) out_dir = tmp_path / "out" stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), out_dir=out_dir, precision="32-true", train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), ) out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: assert {p.name for p in (out_dir / 
checkpoint_dir).iterdir()} == { "lit_model.pth.adapter_v2", "model_config.yaml", "tokenizer_config.json", "tokenizer.json", "hyperparameters.yaml", "prompt_style.yaml", } assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() logs = stdout.getvalue() assert logs.count("(step)") == 6 assert logs.count("val loss") == 4 # 3 validations + 1 final validation assert logs.count("Final evaluation") == 1 assert "of trainable parameters: 552" in logs def test_adapter_v2_gpt_init_weights(): config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0) model = AdapterV2GPT(config) for param in (model.transformer.h[0].attn.gating_factor, model.lm_head.adapter_bias): assert (param == 0).all() torch.nn.init.constant_(param, 1.23) assert (param != 0).any() model.apply(model._init_weights) assert (param == 0).all() @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_adapter_v2_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0) keys = lora_model.load_state_dict(base_model_state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert adapter_filter(k, None) @_RunIf(dynamo=True) @torch.inference_mode() def test_adapter_v2_compile(): model = AdapterV2GPT.from_name("pythia-14m", n_layer=3) x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64) explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 model = AdapterV2GPT(model.config) model.set_kv_cache(2) input_pos = torch.arange(model.config.block_size) explanation = torch._dynamo.explain(model)(x, input_pos) 
assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 @torch.inference_mode() def test_against_hf_mixtral(): device = torch.device("cpu") dtype = torch.float32 ours_config = Config.from_name( "Mixtral-8x7B-Instruct-v0.1", padded_vocab_size=10000, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, n_expert=4, ) T = 5 theirs_config = MixtralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, num_local_experts=ours_config.n_expert, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = MixtralForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) # strict=False because missing keys due to adapter weights not contained in state dict ours_model.load_state_dict(state_dict, strict=False) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) def test_against_hf_gemma(model_name): device = torch.device("cpu") dtype = torch.float32 T = 5 ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, 
head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert adapter_filter(k, None) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b")) def test_against_original_gemma_2(model_name): device = torch.device("cpu") dtype = torch.float32 T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, 
num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Gemma2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert adapter_filter(k, None) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close( # some macOS devices have numerical differences, hence the tol bump ours_y, theirs_y, atol=1e-4, rtol=1e-5, ) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it")) def test_against_original_gemma_3(model_name): device = torch.device("cpu") dtype = torch.float32 T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, 
num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, rope_scaling={"factor": 8.0, "rope_type": "linear"}, rope_local_base_freq=ours_config.rope_local_base_freq, ) theirs_model = Gemma3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_3({}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert adapter_filter(k, None) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close( ours_y, theirs_y, rtol=3e-5, atol=3e-5 ) # some macOS devices have numerical differences, hence the tol bump @_RunIf(min_cuda_gpus=1) def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): if not _BITSANDBYTES_AVAILABLE: pytest.skip("BNB not available") from bitsandbytes.optim import PagedAdamW model_config = dict( block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True ) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) 
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) monkeypatch.setattr(module, "load_checkpoint", Mock()) train_mock = Mock() train_mock.return_value = { "raw_tokens": 1000, "raw_tokens_plus_prompt_template": 1100, "raw_tokens_plus_prompt_template_and_padding": 1200, } monkeypatch.setattr(module, "fit", train_mock) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), precision="16-true", quantize="bnb.nf4-dq", out_dir=tmp_path, ) _, kwargs = train_mock.call_args fabric = kwargs["fabric"] model = kwargs["model"] optimizer = kwargs["optimizer"] assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) assert isinstance(optimizer, _FabricOptimizer) assert isinstance(optimizer._optimizer, PagedAdamW) dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} for name, layer in model.named_parameters(): name = name[len("_forward_module.") :] dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.uint8": { "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", "transformer.h.1.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", "transformer.h.0.attn.qkv.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.1.mlp.proj.adapter_bias", "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.qkv.linear.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.mlp.proj.linear.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.h.0.attn.proj.adapter_scale", 
"transformer.h.0.attn.qkv.adapter_scale", "transformer.h.1.norm_2.bias", "transformer.h.1.attn.proj.adapter_scale", "transformer.h.0.norm_2.bias", "transformer.h.0.mlp.fc.adapter_scale", "transformer.h.0.attn.proj.linear.bias", "transformer.h.1.attn.proj.linear.bias", "transformer.h.1.norm_1.bias", "transformer.h.0.norm_1.weight", "transformer.h.1.attn.proj.adapter_bias", "transformer.h.0.mlp.proj.adapter_scale", "transformer.h.0.mlp.proj.adapter_bias", "transformer.h.1.mlp.fc.adapter_bias", "transformer.h.1.mlp.proj.adapter_scale", "transformer.h.1.attn.gating_factor", "transformer.h.1.norm_1.weight", "transformer.ln_f.bias", "transformer.h.0.mlp.fc.adapter_bias", "lm_head.adapter_scale", "lm_head.adapter_bias", "transformer.h.1.norm_2.weight", "transformer.h.0.attn.adapter_wte.weight", "transformer.h.1.attn.qkv.adapter_scale", "transformer.h.1.mlp.fc.adapter_scale", "transformer.h.1.attn.qkv.linear.bias", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.0.norm_2.weight", "transformer.h.1.mlp.proj.linear.bias", "transformer.h.0.attn.gating_factor", "transformer.h.0.attn.proj.adapter_bias", }, } assert {p.name for p in tmp_path.rglob("*.pth.adapter_v2")} == {"lit_model.pth.adapter_v2"} state_dict = torch.load(tmp_path / "final" / "lit_model.pth.adapter_v2") assert len(state_dict) == 1 dtype_to_name = {"torch.float16": set()} for name, layer in state_dict["model"].items(): dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.attn.proj.adapter_bias", "transformer.h.1.mlp.fc.adapter_scale", "lm_head.adapter_bias", "transformer.h.0.mlp.proj.adapter_scale", "transformer.ln_f.bias", "lm_head.adapter_scale", "transformer.h.1.norm_2.weight", "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.0.mlp.proj.adapter_bias", "transformer.h.0.attn.gating_factor", "transformer.h.1.norm_1.bias", "transformer.h.1.mlp.fc.adapter_bias", 
"transformer.h.1.mlp.proj.adapter_scale", "transformer.h.0.mlp.fc.adapter_scale", "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.0.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.attn.proj.adapter_scale", "transformer.h.1.mlp.proj.adapter_bias", "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.1.attn.gating_factor", "transformer.h.0.mlp.fc.adapter_bias", "transformer.h.1.attn.proj.adapter_scale", "transformer.h.0.attn.proj.adapter_bias", "transformer.h.0.norm_1.bias", "transformer.h.0.norm_2.bias", "transformer.h.1.norm_1.weight", "transformer.h.1.attn.qkv.adapter_scale", } } logs = stdout.getvalue() assert "of trainable parameters: 552" in logs assert "of non-trainable parameters: 1,808" in logs def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, ) attention_1 = CausalSelfAttention(config=config, block_idx=0) # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` # and make them interleaved state_dict = deepcopy(attention_1.state_dict()) state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) ================================================ FILE: tests/test_api.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os import re import sys from collections import OrderedDict from pathlib import Path from unittest.mock import MagicMock, patch import pytest import torch from lightning.fabric.accelerators import CUDAAccelerator from litgpt.api import LLM, benchmark_dict_to_markdown_table, calculate_number_of_devices from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf skip_in_ci_on_macos = pytest.mark.skipif( sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true", reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.", ) if sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true": USE_MPS = False elif torch.backends.mps.is_available(): USE_MPS = True else: USE_MPS = False @pytest.fixture def mock_llm(): llm = MagicMock(spec=LLM) llm.model = MagicMock() llm.preprocessor = MagicMock() llm.prompt_style = MagicMock() llm.checkpoint_dir = MagicMock() llm.fabric = MagicMock() return llm def test_load_model(mock_llm): assert isinstance(mock_llm, LLM) assert mock_llm.model is not None assert mock_llm.preprocessor is not None assert mock_llm.prompt_style is not None assert mock_llm.checkpoint_dir is not None assert mock_llm.fabric is not None def test_generate(mock_llm): prompt = "What do Llamas eat?" mock_llm.generate.return_value = prompt + " Mock output" output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5) assert isinstance(output, str) assert len(output) > len(prompt) def test_stream_generate(mock_llm): prompt = "What do Llamas eat?" def iterator(): outputs = (prompt + " Mock output").split() yield from outputs mock_llm.generate.return_value = iterator() output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5, stream=True) result = "".join([out for out in output]) assert len(result) > len(prompt) def test_generate_token_ids(mock_llm): prompt = "What do Llamas eat?" 
mock_output_ids = MagicMock(spec=torch.Tensor) mock_output_ids.shape = [len(prompt) + 10] mock_llm.generate.return_value = mock_output_ids output_ids = mock_llm.generate(prompt, max_new_tokens=10, return_as_token_ids=True) assert isinstance(output_ids, torch.Tensor) assert output_ids.shape[0] > len(prompt) def test_calculate_number_of_devices(): assert calculate_number_of_devices(1) == 1 assert calculate_number_of_devices([0, 1, 2]) == 3 assert calculate_number_of_devices(None) == 0 def test_llm_load_random_init(tmp_path): download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) torch.manual_seed(123) with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load(model="pythia-160m", init="random", tokenizer_dir=Path(tmp_path / "EleutherAI/pythia-14m")) input_text = "some text text" output_text = llm.generate(input_text, max_new_tokens=15) ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text)) assert ln <= 15 # The following below tests that generate works with different prompt lengths # after the kv cache was set input_text = "some text" output_text = llm.generate(input_text, max_new_tokens=15) ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text)) assert ln <= 15 input_text = "some text text text" output_text = llm.generate(input_text, max_new_tokens=15) ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text)) assert ln <= 15 def test_llm_load_hub_init(tmp_path): torch.manual_seed(123) with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained") text_1 = llm.generate("text", max_new_tokens=10, top_k=1) assert len(text_1) > 0 text_2 = llm.generate("text", max_new_tokens=10, top_k=1, stream=True) text_2 = "".join(list(text_2)) assert text_1 == text_2, (text_1, text_2) def 
test_model_not_initialized(tmp_path): llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained", distribute=None) s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model." with pytest.raises(AttributeError, match=re.escape(s)): llm.generate("text") llm = LLM.load(model="EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random", distribute=None) s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model." with pytest.raises(AttributeError, match=re.escape(s)): llm.generate("text") @_RunIf(min_cuda_gpus=2) def test_more_than_1_device_for_sequential_gpu(tmp_path): device_count = CUDAAccelerator.auto_device_count() if device_count <= 2: model_name = "EleutherAI/pythia-14m" else: model_name = "EleutherAI/pythia-160m" with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model=model_name, ) with pytest.raises( NotImplementedError, match="Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'.", ): llm.distribute(devices=2) llm.distribute(devices=2, generate_strategy="sequential") assert isinstance(llm.generate("What do llamas eat?"), str) assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0" last_layer_idx = len(llm.model.transformer.h) - 1 assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == "cuda:1" # Also check with default (devices="auto") setting llm.distribute(generate_strategy="sequential") assert isinstance(llm.generate("What do llamas eat?"), str) assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0" assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == f"cuda:{device_count - 1}" @_RunIf(min_cuda_gpus=2) @pytest.mark.skipif(bool(os.getenv("SKIP_WITH_CI")), reason="Skip this test in CI due to ...") def test_more_than_1_device_for_tensor_parallel_gpu(tmp_path): with 
patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load(model="EleutherAI/pythia-14m") # this crashes the CI, maybe because of process forking; works fine locally though llm.distribute(devices=2, generate_strategy="tensor_parallel") assert isinstance(llm.generate("What do llamas eat?"), str) @_RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("strategy", ("sequential", "tensor_parallel")) @pytest.mark.xfail( NotADirectoryError, reason="This test is expected to fail due to a NotADirectoryError.", strict=False ) def test_sequential_tp_incompatibility_with_random_weights(strategy, tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load(model="EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random") with pytest.raises( NotImplementedError, match=re.escape( "The LLM was initialized with init='random' but .distribute() currently only supports pretrained weights." ), ): llm.distribute(devices=1, generate_strategy=strategy) @pytest.mark.parametrize("strategy", ("sequential", "tensor_parallel")) def test_sequential_tp_cpu(strategy, tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", distribute=None, ) with pytest.raises( NotImplementedError, match=f"generate_strategy='{strategy}' is only supported for accelerator='cuda'|'gpu'." ): llm.distribute(devices=1, accelerator="cpu", generate_strategy=strategy) def test_initialization_for_trainer(tmp_path): llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None) s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model." 
with pytest.raises(AttributeError, match=re.escape(s)): llm.generate("hello world") llm.trainer_setup() llm.model.to(llm.preprocessor.device) assert isinstance(llm.generate("hello world"), str) @_RunIf(min_cuda_gpus=1) def test_quantization_is_applied(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) llm.distribute(devices=1, quantize="bnb.nf4", precision="bf16-true") strtype = str(type(llm.model.lm_head)) assert "NF4Linear" in strtype, strtype @_RunIf(min_cuda_gpus=1) def test_fixed_kv_cache(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) llm.distribute(devices=1, fixed_kv_cache_size=100) # Request too many tokens with pytest.raises(NotImplementedError, match="max_seq_length 512 needs to be >= 9223372036854775809"): _ = llm.generate("hello world", max_new_tokens=2**63) def test_invalid_accelerator(tmp_path): llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None) with pytest.raises(ValueError, match="Invalid accelerator"): llm.distribute(accelerator="invalid") def test_returned_benchmark_dir(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) text, bench_d = llm.benchmark(prompt="hello world") assert isinstance(bench_d["Inference speed in tokens/sec"], list) assert len(bench_d["Inference speed in tokens/sec"]) == 1 assert isinstance(bench_d["Inference speed in tokens/sec"][0], float) text, bench_d = llm.benchmark(prompt="hello world", stream=True) assert isinstance(bench_d["Inference speed in tokens/sec"], list) assert len(bench_d["Inference speed in tokens/sec"]) == 1 assert isinstance(bench_d["Inference speed in tokens/sec"][0], float) text, bench_d = llm.benchmark(num_iterations=10, prompt="hello world", stream=True) assert isinstance(bench_d["Inference speed in tokens/sec"], list) assert len(bench_d["Inference speed in 
tokens/sec"]) == 10 assert isinstance(bench_d["Inference speed in tokens/sec"][0], float) def test_benchmark_dict_to_markdown_table_single_values(): bench_d = { "Inference speed in tokens/sec": [17.617540650112936], "Seconds to first token": [0.6533610639999097], "Seconds total": [1.4758019020000575], "Tokens generated": [26], "Total GPU memory allocated in GB": [5.923729408], } expected_output = ( "| Metric | Mean | Std Dev |\n" "|-------------------------------------|-----------------------------|-----------------------------|\n" "| Inference speed in tokens/sec | 17.62 | nan |\n" "| Seconds to first token | 0.65 | nan |\n" "| Seconds total | 1.48 | nan |\n" "| Tokens generated | 26.00 | nan |\n" "| Total GPU memory allocated in GB | 5.92 | nan |\n" ) assert benchmark_dict_to_markdown_table(bench_d) == expected_output def test_benchmark_dict_to_markdown_table_multiple_values(): bench_d_list = { "Inference speed in tokens/sec": [ 17.034547562152305, 32.8974175404589, 33.04784205046782, 32.445697744648584, 33.204480197756396, 32.64187570945661, 33.21232058140845, 32.69377798373551, 32.92351459309756, 32.48909032591177, ], "Seconds to first token": [ 0.7403525039999295, 0.022901020000063, 0.02335712100011733, 0.022969672000272112, 0.022788318000039, 0.02365505999978268, 0.02320190000000366, 0.022791139999753796, 0.022871761999795126, 0.023060415999680117, ], "Seconds total": [ 1.5263099829999192, 0.7903355929997815, 0.7867382069998712, 0.8013389080001616, 0.7830268640000213, 0.7965228539997042, 0.7828420160003589, 0.7952583520000189, 0.7897091279996857, 0.8002686360000553, ], "Tokens generated": [26, 26, 26, 26, 26, 26, 26, 26, 26, 26], "Total GPU memory allocated in GB": [ 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, 5.923729408, ], } expected_output = ( "| Metric | Mean | Std Dev |\n" "|-------------------------------------|-----------------------------|-----------------------------|\n" "| 
Inference speed in tokens/sec | 31.26 | 5.01 |\n" "| Seconds to first token | 0.09 | 0.23 |\n" "| Seconds total | 0.87 | 0.23 |\n" "| Tokens generated | 26.00 | 0.00 |\n" "| Total GPU memory allocated in GB | 5.92 | 0.00 |\n" ) assert benchmark_dict_to_markdown_table(bench_d_list) == expected_output def test_state_dict(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) assert isinstance(llm.state_dict(), OrderedDict) assert llm.state_dict()["lm_head.weight"].shape == torch.Size([50304, 128]) def test_save_method(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) target_dir = "saved_model" llm.save(target_dir) expected_files = [ "config.json", "generation_config.json", "lit_model.pth", "model_config.yaml", "prompt_style.yaml", "tokenizer_config.json", "tokenizer.json", ] files_in_directory = os.listdir(target_dir) for file_name in expected_files: assert file_name in files_in_directory, f"{file_name} is missing from {target_dir}" def test_forward_method(tmp_path): with patch("torch.backends.mps.is_available", return_value=USE_MPS): llm = LLM.load( model="EleutherAI/pythia-14m", ) inputs = torch.ones(6, 128, dtype=torch.int64).to(next(llm.model.parameters()).device) assert llm(inputs).shape == torch.Size([6, 128, 50304]) logits, loss = llm(inputs, target_ids=inputs) assert logits.shape == torch.Size([6, 128, 50304]) assert isinstance(loss.item(), float) @skip_in_ci_on_macos # The macOS CI machine segfaults here (it works fine locally though) def test_precision_selection(tmp_path): llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained") llm.distribute(precision="16-true") assert llm.model._forward_module.lm_head.weight.dtype == torch.float16, ( f"Expected float16, but got {llm.model._forward_module.lm_head.weight.dtype}" ) ================================================ FILE: tests/test_args.py 
================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import pytest from litgpt.args import TrainArgs def test_compute_warmup_iters(): # warmup disabled train = TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=0) assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 0 # lr_warmup_steps and lr_warmup_fraction both are not allowed with pytest.raises(ValueError, match="Can't provide both `--train.lr_warmup_fraction`"): TrainArgs(lr_warmup_steps=1, lr_warmup_fraction=0.2) # lr_warmup_fraction invalid range with pytest.raises(ValueError, match=" must be between 0 and 1"): TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=1.1) # lr_warmup_steps train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=0) assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 100 # lr_warmup_steps multiplied by accumulation factor train.global_batch_size = 4 assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 400 assert train.warmup_iters(devices=2, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 200 # lr_warmup_steps truncated by max iters assert train.warmup_iters(devices=1, num_nodes=1, max_iters=120, train_dataloader=range(10)) == 120 # lr_warmup_fraction train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=0, lr_warmup_fraction=0.3) assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(100)) == 30 # lr_warmup_fraction truncated by max iters assert train.warmup_iters(devices=1, num_nodes=1, max_iters=20, train_dataloader=range(100)) == 20 # lr_warmup_fraction rounds up assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(5)) == 2 ================================================ FILE: tests/test_batch.py 
================================================ import warnings from pathlib import Path import lightning as L import pytest import torch import litgpt from litgpt.api import GPT, LLM from litgpt.generate.base import ( batched_generate_fn, batched_next_token, generate_fn, next_token, ) from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf warnings.filterwarnings("ignore") def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]: L.seed_everything(42) model_name = "microsoft/phi-2" download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) llm: LLM = LLM.load( model_name, tokenizer_dir=Path(tmp_path / model_name), init="random", ) model: GPT = llm.model model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length, device=device) return llm, model @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires a GPU.") def test_batched_equivalence(tmp_path): model_name = "microsoft/phi-2" download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) device = "cuda:0" batch_size = 3 sample_kwargs = {"top_k": 1} llm: LLM = LLM.load( model_name, tokenizer_dir=Path(tmp_path / model_name), init="random", ) model: GPT = llm.model model.set_kv_cache(batch_size=1, max_seq_length=50, device=device) input_pos_1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device) input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device) x = torch.tensor( [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410], device=device, dtype=torch.int64, ) batch_x1 = torch.stack([x] * batch_size, dim=0) # Single token generation baseline tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs) tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs) assert tok_1.ndim == 1 assert tok_2.ndim == 1 assert tok_1.size(0) == 1 assert tok_2.size(0) == 1 # Switch to batched generation model.clear_kv_cache() 
model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0") toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs) toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs) assert toks_1.ndim == 2 assert toks_2.ndim == 2 assert toks_1.size(0) == batch_size assert toks_2.size(0) == batch_size # Assert that single and batched next token generation are equivalent assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}" assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}" @_RunIf(min_cuda_gpus=1) def test_simple_batch(): old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 torch.backends.cuda.matmul.allow_tf32 = False config = litgpt.Config.from_name("microsoft/phi-2", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256) with torch.device("cuda"): m = litgpt.GPT(config).requires_grad_(False).eval() x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]]) input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]]) x1 = torch.tensor([[1], [2]]) input_pos1 = torch.tensor([[4], [3]]) with torch.device("cuda"): m.set_kv_cache(2) outs0 = m(x0, input_pos0) outs1 = m(x1, input_pos1) with torch.device("cuda"): m.set_kv_cache(1) outs0_ref0 = m(x0[:1], input_pos0[0]) outs1_ref0 = m(x1[:1], input_pos1[0]) with torch.device("cuda"): m.set_kv_cache(1) outs0_ref1 = m(x0[1:], input_pos0[1]) outs1_ref1 = m(x1[1:], input_pos1[1]) outs0_ref = torch.cat([outs0_ref0, outs0_ref1]) outs1_ref = torch.cat([outs1_ref0, outs1_ref1]) print(outs0_ref - outs0) print(outs0.shape) torch.testing.assert_close(outs0, outs0_ref) torch.testing.assert_close(outs1, outs1_ref) torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32 @_RunIf(min_cuda_gpus=1) def test_batch_generate(tmp_path): torch.use_deterministic_algorithms(True) device = "cuda:0" batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) batch_x = torch.tensor( [ [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 
198, 26410], [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596], [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596], ], device=device, dtype=torch.int64, ) # Generate tokens tokens = [] for l in batched_generate_fn( model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, include_prompt=True, include_eos=False, ): tokens.append([t.item() if t is not None else None for t in l]) def find_unique_stop(triplets): # Initialize a dictionary to count all number occurrences number_count = {} # Count occurrences of each number across all positions for triplet in triplets: for num in triplet: number_count[num] = number_count.get(num, 0) + 1 # Initialize lists to store unique numbers for each position unique_first = [] unique_second = [] unique_third = [] # Check each triplet for a, b, c in triplets: if number_count[a] == 1: unique_first.append(a) if number_count[b] == 1: unique_second.append(b) if number_count[c] == 1: unique_third.append(c) import random # Seeded earlier random.shuffle(unique_first) random.shuffle(unique_second) random.shuffle(unique_third) return [unique_first[0], unique_second[0], unique_third[0]] # Now that we know the randomly generated tokens, sample some tokens to stop each stream at. stops = find_unique_stop(tokens[batch_x.size(1) :]) first_stream = [t[0] for t in tokens if t[0] is not None] second_stream = [t[1] for t in tokens if t[1] is not None] third_stream = [t[2] for t in tokens if t[2] is not None] # Let's slice the streams at the stop tokens. stop_idxes = [ first_stream.index(stops[0]), second_stream.index(stops[1]), third_stream.index(stops[2]), ] # While we're at it, grab the last token that would be generated before stopping. last_tokens = [ first_stream[stop_idxes[0] - 1], second_stream[stop_idxes[1] - 1], third_stream[stop_idxes[2] - 1], ] for t in tokens: print(t) # Now we generate again, stopping early at the stop tokens. 
tokens = [] for l in batched_generate_fn( model, prompts=batch_x, max_returned_tokens=50, stop_tokens=[(s,) for s in stops], sample_args=sample_kwargs, include_prompt=True, include_eos=False, ): tokens.append([t.item() if t is not None else None for t in l]) # Finally, assert that the streams are correct. first_stream = [t[0] for t in tokens if t[0] is not None] print(first_stream) print(len(first_stream), stop_idxes[0]) assert len(first_stream) == stop_idxes[0] assert first_stream[-1] == last_tokens[0] second_stream = [t[1] for t in tokens if t[1] is not None] print(second_stream) print(len(second_stream), stop_idxes[1]) assert len(second_stream) == stop_idxes[1] assert second_stream[-1] == last_tokens[1] third_stream = [t[2] for t in tokens if t[2] is not None] print(third_stream) print(len(third_stream), stop_idxes[2]) assert len(third_stream) == stop_idxes[2] assert third_stream[-1] == last_tokens[2] torch.use_deterministic_algorithms(False) # for t in llm.tokenizer.decode_stream([torch.tensor(i) for i in first_stream]): # print(t, end="", flush=True) # print() @_RunIf(min_cuda_gpus=1) def test_batch_generate_equivalence(tmp_path): torch.use_deterministic_algorithms(True) device = "cuda:0" batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) batch_x = torch.tensor( [ [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410], [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596], [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596], ], device=device, dtype=torch.int64, ) # The other test tests the stop_tokens functionality much more exhaustively, we'll just generate and compare 50 tokens here. 
batch_tokens = [] for l in batched_generate_fn( model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, include_prompt=False, include_eos=False, ): batch_tokens.append([t.item() if t is not None else None for t in l]) first_stream = [t[0] for t in batch_tokens if t[0] is not None] batch_size = 1 llm, model = create_llm(tmp_path, batch_size, 50, device) tokens = [] for t in generate_fn( model, prompt=batch_x[0], max_returned_tokens=50, include_prompt=False, include_eos=False, **sample_kwargs, ): if t.size(0) == 1: tokens.append(t.item()) else: tokens.extend(t.tolist()) torch.use_deterministic_algorithms(False) # TODO: (apaz-cli) This consistency test doesn't actually work at the moment. It's inconsistent. # The output is really close... Something is going on here. For the moment, maybe this is close enough? # Enough at least that we can start prototyping. print(first_stream) print(tokens) # assert first_stream == tokens ================================================ FILE: tests/test_chat.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os
import re
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from itertools import repeat
from pathlib import Path
from typing import Iterable, Iterator
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch
import yaml

import litgpt.chat.base as chat
import litgpt.generate.base as generate
from litgpt import Config, Tokenizer
from litgpt.utils import auto_download_checkpoint, save_config

# Marker that skips memory-heavy tests on macOS runners in GitHub Actions.
skip_in_ci_on_macos = pytest.mark.skipif(
    sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true",
    reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.",
)


@pytest.mark.parametrize(
    ("generated", "stop_tokens", "expected"),
    [
        (repeat(1), (), [1] * 8),
        ([1, 2, 3, 0], ([0],), [1, 2, 3]),
        ([1, 2, 3, 0], ([9], [2, 4], [1, 2, 3, 0]), []),
        ([1, 2, 3, 0, 0], ([0, 0, 0], [0, 0]), [1, 2, 3]),
        ([3, 1, 2], ([1, 2], [3]), []),
        ([1, 2, 3, 0, 3, 2, 1, 0], ([4, 3, 2, 1], [2, 4]), [1, 2, 3, 0, 3, 2, 1, 0]),
    ],
)
def test_generate(monkeypatch, generated, stop_tokens, expected):
    """Drive ``chat.generate`` with a scripted token sequence and check what it yields.

    ``generate.multinomial_num_samples_1`` is patched to return the next value from ``generated``,
    so the model itself (a ``MagicMock``) never decides the tokens. ``expected`` is compared against
    the concatenation of all yielded chunks, and its length against the number of chunks.
    """
    import lightning as L

    L.seed_everything(1234)

    input_idx = torch.tensor([5, 3])
    # Room for at most 8 new tokens after the 2-token prompt.
    max_returned_tokens = len(input_idx) + 8
    model = MagicMock()
    model.config.block_size = 100
    model.max_seq_length = 100
    it = iter(generated)

    def multinomial(*_, **__):
        # Replaces sampling: emit the next scripted token.
        out = next(it)
        return torch.tensor([out])

    monkeypatch.setattr(generate, "multinomial_num_samples_1", multinomial)
    actual = chat.generate(model, input_idx, max_returned_tokens, stop_tokens=stop_tokens)
    actual = list(actual)

    assert len(actual) == len(expected), (actual, expected)
    if not actual:
        assert actual == expected, (actual, expected)
    else:
        for t in actual:
            assert t.dtype == torch.long, t.dtype
        actual_list = torch.cat(actual).tolist()
        assert actual_list == expected, (actual_list, expected)


def test_decode():
    """Encoding text and decoding it one token at a time via ``decode_stream`` reproduces the text."""
    checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")
    tokenizer = Tokenizer(checkpoint_dir)

    text = (
        "Hello World! This a bunch of text. Lorem ipsum dolor sit amet, "
        "consectetur adipiscing elit, sed do eiusmod tempor incididunt "
        "ut labore et dolore magna aliqua."
    )

    encoded: torch.Tensor = tokenizer.encode(text)
    # Split into one single-token tensor per position to simulate streamed generation.
    encoded_stream: Iterable[torch.Tensor] = torch.tensor_split(encoded, encoded.shape[0], dim=0)

    decoded_stream: Iterator[str] = tokenizer.decode_stream(encoded_stream)
    decoded: str = "".join(decoded_stream)

    # Note that encoded and decoded text will not always be character for character identical.
    # Indeed, sometimes it is not. But that tends to be because of special cases, and this is not
    # one of those.
    assert text == decoded, (text, decoded)


@skip_in_ci_on_macos
@patch("litgpt.chat.base.input")
@pytest.mark.parametrize("stop_iteration", [KeyboardInterrupt, ""])
def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, tensor_like):
    """Run ``chat.main`` with mocked checkpoint loading, tokenizer and generation.

    ``input()`` returns "Hello" and then ``stop_iteration`` (an exception or an empty string), which
    presumably ends the chat loop. The test checks that ``decode_stream`` is called once with the
    mocked generator's return value and that the reply is printed.
    """
    # these values will be iteratively provided for each `input()` call
    mocked_input.side_effect = ["Hello", stop_iteration]

    config_path = fake_checkpoint_dir / "model_config.yaml"
    config = {
        "name": "Llama 3",
        "block_size": 128,
        "vocab_size": 50,
        "n_layer": 2,
        "n_head": 4,
        "n_embd": 8,
        "rotary_percentage": 1,
    }
    config_path.write_text(yaml.dump(config))

    load_mock = Mock()
    load_mock.return_value = load_mock
    monkeypatch.setattr(chat, "load_checkpoint", load_mock)
    tokenizer_mock = Mock()
    tokenizer_mock.return_value.backend = "sentencepiece"
    tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
    tokenizer_mock.return_value.decode_stream.return_value = "foo bar baz"
    monkeypatch.setattr(chat, "Tokenizer", tokenizer_mock)
    generate_mock = MagicMock()
    generate_mock.__iter__.return_value = [torch.tensor([3, 2, 1])]
    monkeypatch.setattr(chat, "generate", generate_mock)

    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        chat.main(temperature=2.0, max_new_tokens=10, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)

    # decoding is done per each generated item
    assert len(tokenizer_mock.return_value.decode_stream.mock_calls) == 1
    assert tokenizer_mock.return_value.decode_stream.call_args[0][0] is generate_mock.return_value  # Now a Mock

    # Assert that the generated result is printed to stdout
    assert re.match(r".*Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL), out.getvalue()


def test_cli():
    """``litgpt chat -h`` (run as a subprocess) prints the chat command description."""
    args = ["litgpt", "chat", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    assert "Chat with a model" in output


@skip_in_ci_on_macos
@patch("litgpt.chat.base.input")
@patch("litgpt.chat.base.merge_lora")
def test_merge_lora_if_needed(mocked_merge_lora, mocked_input, fake_checkpoint_dir, monkeypatch, tensor_like):
    """With only ``lit_model.pth.lora`` present, ``chat.main`` calls ``merge_lora`` and reports it."""
    # these values will be iteratively provided for each `input()` call
    mocked_input.side_effect = [""]

    # pretend there is an unmerged LORA checkpoint
    os.rename(fake_checkpoint_dir / "lit_model.pth", fake_checkpoint_dir / "lit_model.pth.lora")
    # The mocked merge recreates lit_model.pth so the rest of main() can proceed.
    mocked_merge_lora.side_effect = lambda _: Path(fake_checkpoint_dir / "lit_model.pth").touch()

    config = Config.from_name("pythia-14m")
    save_config(config, fake_checkpoint_dir)
    monkeypatch.setattr(chat, "load_checkpoint", Mock())
    monkeypatch.setattr(chat, "Tokenizer", Mock())

    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        chat.main(checkpoint_dir=fake_checkpoint_dir)

    assert re.match(r".*Merging LoRA weights with the base model\..*", out.getvalue(), re.DOTALL)
    mocked_merge_lora.assert_called_once()


@skip_in_ci_on_macos
def test_litgpt_chat_endtoend():
    """End-to-end chat with the real pythia-14m checkpoint: one prompt, then KeyboardInterrupt."""
    from litgpt.chat.base import main

    checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")

    # Patch input() and redirect stdout. Raise to exit the repl.
    simulated_input = Mock(side_effect=["input", KeyboardInterrupt])
    captured_output = StringIO()
    with patch("builtins.input", simulated_input):
        with redirect_stdout(captured_output):
            try:
                main(checkpoint_dir=checkpoint_dir, max_new_tokens=256, top_k=1)
            except KeyboardInterrupt:
                pass

    # pythia-14m is not instruct-tuned, so it does not give an "answer" per se, but a continuation.
    output = captured_output.getvalue()
    assert ">> Reply: " in output, f"Expected reply not found. Got:\n{output}"
    # Verify the model actually generated some text after the reply prompt
    reply_start = output.index(">> Reply: ") + len(">> Reply: ")
    assert len(output[reply_start:].strip()) > 0, f"Expected non-empty reply. Got:\n{output}"
    assert simulated_input.call_count == 2


@skip_in_ci_on_macos
def test_litgpt_generate_endtoend():
    """End-to-end ``generate.base.main`` with the real pythia-14m checkpoint and a fixed prompt."""
    from litgpt.generate.base import main

    checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")
    captured_output = StringIO()
    with redirect_stdout(captured_output):
        try:
            main(checkpoint_dir=checkpoint_dir, prompt="Hello World", max_new_tokens=256, top_k=1)
        except KeyboardInterrupt:
            pass

    # pythia-14m is not instruct-tuned, so it does not give an "answer" per se, but a continuation.
    assert "Hello World!" in captured_output.getvalue(), (
        f"Expected output not found. Got:\n{captured_output.getvalue()}"
    )


================================================
FILE: tests/test_ci.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE from litgpt.utils import _RunIf @_RunIf(min_cuda_gpus=1) def test_gpu_ci_installs_bitsandbytes(): assert _BITSANDBYTES_AVAILABLE, str(_BITSANDBYTES_AVAILABLE) ================================================ FILE: tests/test_cli.py ================================================ import sys from contextlib import redirect_stdout from io import StringIO from unittest import mock import pytest from packaging.version import Version from litgpt.__main__ import main def test_cli(): out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "-h"]): main() out = out.getvalue() assert "usage: litgpt" in out assert ( "{download,chat,finetune,finetune_lora,finetune_full,finetune_adapter,finetune_adapter_v2," "pretrain,generate,generate_full,generate_adapter,generate_adapter_v2,generate_sequentially," "generate_speculatively,generate_tp,convert_to_litgpt,convert_from_litgpt,convert_pretrained_checkpoint," "merge_lora,evaluate,serve}" in out ) assert ( """Available subcommands: download Download weights or tokenizer data from the Hugging Face Hub. chat Chat with a model.""" in out ) assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out assert """serve Serve a LitGPT model using LitServe.""" in out out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune_lora", "-h"]): main() out = out.getvalue() assert ( """--lora_alpha LORA_ALPHA The LoRA alpha. 
(type: int, default: 16)""" in out ) if Version(f"{sys.version_info.major}.{sys.version_info.minor}") < Version("3.9"): # python 3.8 prints `Union[int, null]` instead of `Optional[int]` return out = StringIO() with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "pretrain", "-h"]): main() out = out.getvalue() print(out) assert ( """--train.max_tokens MAX_TOKENS Total number of tokens to train on (type: Optional[int], default: 3000000000000)""" in out ) def test_pretrain_allows_max_steps(): # Ensure --train.max_steps is accepted by the CLI for pretrain # and only emits a warning instead of raising a validation error. args = [ "litgpt", "pretrain", "pythia-14m", "--train.max_steps=1", "--out_dir=out/test-cli", ] with pytest.warns(UserWarning, match="max_steps"): try: with mock.patch("sys.argv", args): main() except Exception: pass def test_rewrite_finetune_command(): out1 = StringIO() with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "fineune", "-h"]): main() out2 = StringIO() with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "fineune_lora", "-h"]): main() assert out1.getvalue() == out2.getvalue() ================================================ FILE: tests/test_config.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import pytest import yaml import litgpt.config as config_module from litgpt import Config from litgpt.config import find_multiple def test_config(): config = Config() assert config.name == "" assert config.block_size == 4096 config = Config(block_size=2048) assert config.block_size == 2048 config = Config.from_name("pythia-14m") assert config.block_size == 512 config = Config.from_name("pythia-14m", block_size=4096) assert config.block_size == 4096 config = Config(hf_config={"name": "pythia-14m"}) assert config.name == "pythia-14m" def test_from_hf_name(): # by short-hand name config0 = Config.from_name("tiny-llama-1.1b") # or by huggingface hub repo name config1 = Config.from_name("TinyLlama-1.1B-intermediate-step-1431k-3T") assert config0 is not None assert config1 is not None assert config0 == config1 def test_nonexisting_name(): with pytest.raises(ValueError, match="'invalid-model-name' is not a supported config name"): Config.from_name("invalid-model-name") @pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs]) def test_short_and_hf_names_are_equal_unless_on_purpose(config): # by short-hand name config0 = Config.from_name(config["name"]) # or by huggingface hub repo name config1 = Config.from_name(config["hf_config"]["name"]) assert config0.name == config1.name def test_from_hf_name_with_org_string(): # Test case 1: valid input config0 = Config.from_name("tiny-llama-1.1b") config1 = Config.from_name("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T") assert config0 is not None assert config1 is not None assert config0 == config1 # Test case 2: invalid input - org not found with pytest.raises( ValueError, match="'UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T' is not a supported config name" ): Config.from_name("UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T") # Test case 3: invalid input - name not found with pytest.raises(ValueError, match="'TinyLlama/TinyLlama-XYZ' is not a supported config 
name"): Config.from_name("TinyLlama/TinyLlama-XYZ") def test_from_checkpoint(tmp_path): # 1. Neither `lit_config.py` nor matching config exists. with pytest.raises(FileNotFoundError, match="neither 'model_config.yaml' nor matching config exists"): Config.from_checkpoint(tmp_path / "non_existing_checkpoint") # 2. If `lit_config.py` doesn't exists, but there is a matching config in `litgpt/config.py`. config = Config.from_checkpoint(tmp_path / "pythia-14m") assert config.name == "pythia-14m" assert config.block_size == 512 assert config.n_layer == 6 # 3. If only `lit_config.py` exists. config_data = {"name": "pythia-14m", "block_size": 24, "n_layer": 2} with open(tmp_path / "model_config.yaml", "w", encoding="utf-8") as file: yaml.dump(config_data, file) config = Config.from_checkpoint(tmp_path) assert config.name == "pythia-14m" assert config.block_size == 24 assert config.n_layer == 2 # 4. Both `lit_config.py` and a matching config exist, but `lit_config.py` supersedes matching config (tmp_path / "pythia-14m").mkdir() with open(tmp_path / "pythia-14m/model_config.yaml", "w", encoding="utf-8") as file: yaml.dump(config_data, file) config = Config.from_checkpoint(tmp_path / "pythia-14m") assert config.name == "pythia-14m" assert config.block_size == 24 assert config.n_layer == 2 @pytest.mark.parametrize("head_size", [None, 128]) def test_head_size(head_size): config = Config(head_size) assert config.head_size == head_size or config.n_embd // config.n_head def test_find_multiple(): assert find_multiple(17, 5) == 20 assert find_multiple(30, 7) == 35 assert find_multiple(10, 2) == 10 assert find_multiple(5, 10) == 10 assert find_multiple(50254, 128) == 50304 assert find_multiple(50254, 256) == 50432 assert find_multiple(50254, 512) == 50688 ================================================ FILE: tests/test_config_hub.py ================================================ import importlib import importlib.util from pathlib import Path from unittest import mock from 
unittest.mock import Mock import pytest from lightning.fabric.plugins import Precision from litgpt import Config from litgpt.utils import CLI fixed_pairs = [ ("litgpt/pretrain.py", "pretrain/debug.yaml"), ("litgpt/pretrain.py", "pretrain/tinyllama.yaml"), ("litgpt/pretrain.py", "pretrain/tinystories.yaml"), ( "litgpt/pretrain.py", "https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/tinystories.yaml", ), ] config_hub_path = Path(__file__).parent.parent / "config_hub" / "finetune" model_pairs = [] for model_dir in config_hub_path.iterdir(): if model_dir.is_dir(): model_name = model_dir.name for yaml_file in model_dir.glob("*.yaml"): config_name = yaml_file.stem python_file = "litgpt/finetune/full.py" if config_name == "full" else "litgpt/finetune/lora.py" relative_yaml_path = yaml_file.relative_to(config_hub_path.parent) model_pairs.append((python_file, str(relative_yaml_path))) all_pairs = fixed_pairs + model_pairs @pytest.mark.parametrize(("script_file", "config_file"), all_pairs) def test_config_help(script_file, config_file, monkeypatch): """Test that configs validate against the signature in the scripts.""" script_file = Path(__file__).parent.parent / script_file assert script_file.is_file() if "http" not in str(config_file): config_file = Path(__file__).parent.parent / "config_hub" / config_file assert config_file.is_file() spec = importlib.util.spec_from_file_location(str(script_file.parent.name), script_file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) monkeypatch.setattr(module, "main", Mock()) monkeypatch.setattr(module, "Tokenizer", Mock()) monkeypatch.setattr(module, "BitsandbytesPrecision", Mock(return_value=Precision()), raising=False) monkeypatch.setattr(module, "Config", Mock(return_value=Config.from_name("pythia-14m"))) monkeypatch.setattr(module, "check_valid_checkpoint_dir", Mock(), raising=False) try: with mock.patch("sys.argv", [script_file.name, "--config", str(config_file), 
"--devices", "1"]): CLI(module.setup) module.main.assert_called_once() except FileNotFoundError: pass # FileNotFound occurs here because we have not downloaded the model weights referenced in the config files # which is ok because here we just want to validate the config file itself. ================================================ FILE: tests/test_deepseek_moe.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import pytest import torch from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM from litgpt import Config from litgpt.model import GPT, LLaMAMLP @torch.inference_mode() @pytest.mark.parametrize("batch_size", (1, 2)) @pytest.mark.parametrize("seq_len", (8, 16)) @pytest.mark.parametrize("device", [torch.device("cpu")]) def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device): """Test MOE litgpt vs hf""" config_litgpt = Config( padded_vocab_size=10000, n_layer=2, vocab_size=10000, n_embd=64, n_head=4, n_query_groups=4, head_size=16, norm_eps=1e-6, bias=False, latent_attention={ "q_lora_rank": 32, "kv_lora_rank": 16, "qk_rope_head_dim": 8, "qk_nope_head_dim": 8, "v_head_dim": 16, }, n_expert=16, n_shared_expert=1, n_expert_per_token=2, n_expert_groups=4, n_topk_groups=2, n_topk_scores_per_group=2, # Note: Deepseek hardcodes this to `2` first_k_dense_replace=1, routed_scaling_factor=2.5, norm_topk_prob=True, moe_intermediate_size=20, mlp_class_name="LLaMAMoE", ) config_hf = DeepseekV3Config( padded_vocab_size=10000, num_hidden_layers=2, vocab_size=10000, hidden_size=64, num_attention_heads=4, num_key_value_heads=4, q_lora_rank=32, kv_lora_rank=16, qk_rope_head_dim=8, qk_nope_head_dim=8, v_head_dim=16, rope_interleave=False, first_k_dense_replace=1, routed_scaling_factor=2.5, norm_topk_prob=True, n_routed_experts=config_litgpt.n_expert, n_shared_experts=config_litgpt.n_shared_expert, num_experts_per_tok=config_litgpt.n_expert_per_token, 
n_group=config_litgpt.n_expert_groups, topk_group=config_litgpt.n_topk_groups, moe_intermediate_size=config_litgpt.moe_intermediate_size, ) model_litgpt = GPT(config_litgpt).to(device) model_litgpt.apply(model_litgpt._init_weights) mlp_litgpt = model_litgpt.transformer.h[0].mlp assert isinstance(mlp_litgpt, LLaMAMLP) # Test first_k_dense_replace (k=1) moe_litgpt = model_litgpt.transformer.h[1].mlp model_hf = DeepseekV3ForCausalLM(config_hf).to(device) moe_hf = model_hf.model.layers[1].mlp moe_litgpt.eval() moe_hf.eval() sync_weights(moe_litgpt, moe_hf) hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device) output_litgpt = moe_litgpt(hidden_states) output_hf = moe_hf(hidden_states) assert torch.allclose(output_litgpt, output_hf, atol=1e-5) def sync_weights(litgpt_model, hf_model): print("Synchronizing MoE weights...") with torch.no_grad(): if hasattr(litgpt_model, "gate"): if hasattr(litgpt_model.gate, "weight"): hf_model.gate.weight.copy_(litgpt_model.gate.weight) if hasattr(litgpt_model.gate, "e_score_correction_bias"): hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias) for i, (litgpt_expert, hf_expert) in enumerate(zip(litgpt_model.experts, hf_model.experts)): hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight) hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight) hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight) if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"): hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight) hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight) hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight) print("MoE weight synchronization complete.") ================================================ FILE: tests/test_distributed.py ================================================ import pytest import torch from lightning import 
Fabric from litgpt.utils import _RunIf @_RunIf(min_cuda_gpus=2, standalone=True) @pytest.mark.parametrize("strategy", ["ddp", "fsdp"]) def test_no_backward_sync(strategy): fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) fabric.launch() # account for sharding in the case of FSDP out_features = 1 if "ddp" in strategy else fabric.world_size model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device) x = torch.randn(1, 1, device=fabric.device) model = fabric.setup(model) # 6 iters, 3 grad accumulation iters for i, enabled in enumerate((True, True, False, True, True, False), 1): x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) with fabric.no_backward_sync(model, enabled): y = model(x) fabric.backward(y.sum()) if not enabled: # Math for the first 3 iters # # DistributedDataParallel # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9 # ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^ # rank0 rank1 allreduce # # thunder.distributed.ddp # ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9 # ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^ # rank0 rank1 allreduce1 rank0 rank1 allreduce2 assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) model.weight.grad = None ================================================ FILE: tests/test_evaluate.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import subprocess from contextlib import redirect_stdout from dataclasses import asdict from io import StringIO from unittest import mock import pytest import torch import yaml import litgpt.eval.evaluate as module from litgpt import GPT, Config from litgpt.scripts.download import download_from_hub @pytest.mark.flaky(reruns=3) def test_evaluate_script(tmp_path): ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) checkpoint_dir = tmp_path / "EleutherAI" / "pythia-14m" ours_model = GPT(ours_config) torch.save(ours_model.state_dict(), checkpoint_dir / "lit_model.pth") with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]): with pytest.raises(ValueError) as excinfo: module.convert_and_evaluate( checkpoint_dir, out_dir=tmp_path / "out_dir", device=None, dtype=torch.float32, limit=5, tasks="logiqa", batch_size=0, # Test for non-positive integer ) assert "batch_size must be a positive integer, 'auto', or in the format 'auto:N'." in str(excinfo.value) with pytest.raises(ValueError) as excinfo: module.convert_and_evaluate( checkpoint_dir, out_dir=tmp_path / "out_dir", device=None, dtype=torch.float32, limit=5, tasks="logiqa", batch_size="invalid", # Test for invalid string ) assert "batch_size must be a positive integer, 'auto', or in the format 'auto:N'." 
in str(excinfo.value) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]): module.convert_and_evaluate( checkpoint_dir, out_dir=tmp_path / "out_dir", device=None, dtype=torch.float32, limit=5, tasks="logiqa", batch_size=1, # Valid case ) stdout = stdout.getvalue() assert (tmp_path / "out_dir" / "results.json").is_file() assert "logiqa" in stdout assert "Metric" in stdout assert "Loading checkpoint shards" not in stdout def test_cli(): args = ["litgpt", "evaluate", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "Evaluate a model with the LM Evaluation Harness" in output ================================================ FILE: tests/test_full.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout from io import StringIO from unittest import mock from unittest.mock import Mock import torch import yaml import litgpt.finetune.full as module from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) monkeypatch.setattr(module, "load_checkpoint", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) out_dir = tmp_path / "out" setup_args = (fake_checkpoint_dir,) setup_kwargs = dict( data=Alpaca(download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0), out_dir=out_dir, precision="32-true", train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, 
micro_batch_size=1), eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), ) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]): module.setup(*setup_args, **setup_kwargs) out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: assert set(os.listdir(out_dir / checkpoint_dir)) == { "lit_model.pth", "model_config.yaml", "tokenizer_config.json", "tokenizer.json", "hyperparameters.yaml", "prompt_style.yaml", } assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() logs = stdout.getvalue() assert logs.count("(step)") == 6 assert logs.count("val loss") == 4 # 3 validations + 1 final validation assert logs.count("Final evaluation") == 1 assert "of trainable parameters: 1,888" in logs # Resume training and do 2 steps more setup_kwargs["train"].max_steps = 8 setup_kwargs["resume"] = True stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]): module.setup(*setup_args, **setup_kwargs) logs = stdout.getvalue() assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs assert logs.count("(step)") == 2 assert out_dir / "step-000008" in set(out_dir.iterdir()) ================================================ FILE: tests/test_generate_speculatively.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import re import subprocess from contextlib import redirect_stderr, redirect_stdout from io import StringIO from unittest.mock import ANY, Mock, call import pytest import torch import yaml from torch import nn import litgpt.generate.speculative_decoding as generate from litgpt import GPT, Config from litgpt.utils import _RunIf def test_speculative_decoding_target_never_accepts_draft_tokens(): class DraftModel(nn.Module): def forward(self, **kwargs): return torch.tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C) class TargetModel(nn.Module): def forward(self, idx, **kwargs): _, T = idx.shape return torch.tensor([[0, 0, 0, 0, 0, 6, 7, 8, 9, 10]] * T, dtype=torch.float)[None, ...] # (B, T, C) draft_model = DraftModel() target_model = TargetModel() token = torch.tensor([-1]) input_pos = torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs ) # target model never accepts draft model's output, thus the output of the `speculative_decoding` # is a single token sampled from the target model assert len(output) == 1 assert output > 5 def test_speculative_decoding_target_always_accepts_draft_tokens(): class DraftModel(nn.Module): def forward(self, **kwargs): return torch.tensor([0, 0, 3, 4, 5, 6, 7, 8, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C) class TargetModel(nn.Module): def forward(self, idx, **kwargs): _, T = idx.shape return torch.tensor([[0, 0, 3, 4, 5, 6, 7, 8, 0, 0]] * T, dtype=torch.float)[None, ...] 
# (B, T, C) draft_model = DraftModel() target_model = TargetModel() token = torch.tensor([-1]) input_pos = torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs ) # target model always accepts draft model's output, thus the output of the `speculative_decoding` # is 4 tokens (3 accepted draft tokens + 1 sampled from target model's output) assert len(output) == 4 assert torch.all((output >= 3) & (output <= 8)) def test_speculative_decoding_target_sometimes_accepts_draft_tokens(): class DraftModel(nn.Module): def forward(self, **kwargs): return torch.tensor([0, 0, 3, 4, 10, 9, 7, 8, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C) class TargetModel(nn.Module): def forward(self, idx, **kwargs): return torch.tensor( [ [0, 0, 0, 0, 10, 9, 0, 0, 0, 0], [0, 0, 0, 0, 10, 9, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 10], ], dtype=torch.float, )[None, ...] 
# (B, T, C) draft_model = DraftModel() target_model = TargetModel() token = torch.tensor([-1]) input_pos = torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs ) # target model accepts only 2 out of 3 draft model's output, thus the output of the `speculative_decoding` # is 3 tokens (2 accepted draft tokens + 1 sampled from adjusted distribution) assert len(output) == 3 assert torch.equal(output, torch.tensor([4, 4, 9])) @pytest.mark.parametrize("max_seq_length", (10, 15, 20, 25)) @pytest.mark.parametrize("speculative_k", (1, 2, 3)) def test_generate(max_seq_length, speculative_k): # create a prompt T = 5 input_idx = torch.arange(0, T) max_new_tokens = max_seq_length - T # prepare models draft_model = GPT(Config(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8)) target_model = GPT(Config(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16)) for model in (draft_model, target_model): model.max_seq_length = max_seq_length model.set_kv_cache(batch_size=1) # generate tokens out, acceptance_rate = generate.generate( draft_model, target_model, input_idx, T + max_new_tokens, top_k=1, speculative_k=speculative_k ) # validate assert out.size(0) == T + max_new_tokens - 1, (out.size(0), T + max_new_tokens - 1) assert 0.0 <= acceptance_rate <= 1.0 @_RunIf(min_cuda_gpus=1) # speculative decoding makes sense only on a GPU def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): # prepare configs for draft and target models draft_model_dir = fake_checkpoint_dir / "draft_model" draft_model_dir.mkdir() target_model_dir = fake_checkpoint_dir / "target_model" target_model_dir.mkdir() draft_model_config = dict(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8) target_model_config = dict(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16) (draft_model_dir / 
"model_config.yaml").write_text(yaml.dump(draft_model_config)) (target_model_dir / "model_config.yaml").write_text(yaml.dump(target_model_config)) # create empty files required for validation for model_dir in (draft_model_dir, target_model_dir): (model_dir / "tokenizer.json").touch() (model_dir / "tokenizer_config.json").touch() (model_dir / "lit_model.pth").touch() # moke functions module_mock = Mock() module_mock.config.block_size = 128 load_mock = Mock() load_mock.return_value = load_mock monkeypatch.setattr(generate, "load_checkpoint", load_mock) tokenizer_mock = Mock() tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3]) tokenizer_mock.return_value.decode.return_value = "foo bar baz" monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() generated_tokens = torch.tensor([3, 2, 1]) acceptance_rate = 0.0 generate_mock.return_value = (generated_tokens, acceptance_rate) monkeypatch.setattr(generate, "generate", generate_mock) # do the sampling num_samples = 2 out, err = StringIO(), StringIO() with redirect_stdout(out), redirect_stderr(err): generate.main( draft_model_checkpoint_dir=draft_model_dir, target_model_checkpoint_dir=target_model_dir, temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, ) assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value[0]) assert ( generate_mock.mock_calls == [ call( ANY, ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, stop_tokens=[tokenizer_mock.return_value.eos_id], speculative_k=3, ) ] * num_samples ) expected_output = "foo bar baz\nAcceptance rate: 0.00%\n" * num_samples # Allow for the config to be printed before the expected repeated strings. 
pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) err_value = err.getvalue() expected_parts = [ "'padded_vocab_size': 512", "'n_layer': 2", "'n_head': 4", ] assert all(part in err_value for part in expected_parts) def test_cli(): args = ["litgpt", "generate_speculatively", "-h"] output = subprocess.check_output(args) output = str(output.decode()) assert "Default generation option" in output ================================================ FILE: tests/test_lora.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout from copy import deepcopy from io import StringIO from itertools import product from unittest import mock from unittest.mock import Mock import pytest import torch import yaml from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging from torch.distributed.device_mesh import init_device_mesh from torch.nn import functional as F from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import litgpt.config as config_module import litgpt.finetune.lora as module from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import GPT as LoRAGPT from litgpt.lora import ( CausalSelfAttention, Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights, ) from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention from litgpt.model import GPT as BaseGPT from 
litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from litgpt.utils import _RunIf def test_lora_layer_replacement(): config = Config(n_layer=2, n_head=4, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1) model = LoRAGPT(config) assert isinstance(model.transformer.h[0].attn, LoRACausalSelfAttention) assert isinstance(model.transformer.h[1].attn, LoRACausalSelfAttention) assert isinstance(model.lm_head, LoRALinear) assert isinstance(model.transformer.h[0].mlp.proj, LoRALinear) def test_lora_merge(): config = Config( n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=True, lora_projection=True, ) model = LoRAGPT(config) model.train() attn_proj = model.transformer.h[0].attn.proj initial_weight = attn_proj.linear.weight.clone() assert torch.equal(attn_proj.linear.weight, initial_weight) # perform an update to the LoRA weights mark_only_lora_as_trainable(model) optimizer = torch.optim.SGD(model.parameters(), lr=1.0) y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64)) y.sum().backward() optimizer.step() optimizer.zero_grad() # the weight remains unchanged (only lora A and B change) assert torch.equal(attn_proj.linear.weight, initial_weight) # calling merge() multiple times in a row should not merge multiple times merge_lora_weights(model) assert attn_proj.merged weight_after = attn_proj.linear.weight.clone() merge_lora_weights(model) merge_lora_weights(model) assert torch.equal(attn_proj.linear.weight, weight_after) # check that `W_after = W_initial + (A x B)` delta_w = attn_proj.get_lora_AB() torch.testing.assert_close(weight_after, initial_weight + delta_w) def test_lora_mqa_gqa(): # MHA config = Config( n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1, 
lora_query=True, lora_value=True, ) assert config.n_query_groups == config.n_head model = LoRAGPT(config) attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 24) bsz, ctx_len, in_dim = 2, 30, 8 x_in = torch.randn(bsz, ctx_len, in_dim) out = attn(x_in) non_lora_ind = list(set(range(24)).difference(lora_ind)) assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 # MQA config.n_query_groups = 1 model = LoRAGPT(config) attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11] assert attn.linear.weight.shape == (12, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (10, 2) assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 12) bsz, ctx_len, in_dim = 2, 30, 8 x_in = torch.randn(bsz, ctx_len, in_dim) out = attn(x_in) non_lora_ind = list(set(range(12)).difference(lora_ind)) assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 # GQA config.n_query_groups = 2 model = LoRAGPT(config) attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape 
== (12, 2) assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 16) bsz, ctx_len, in_dim = 2, 30, 8 x_in = torch.randn(bsz, ctx_len, in_dim) out = attn(x_in) non_lora_ind = list(set(range(16)).difference(lora_ind)) assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 @pytest.mark.parametrize( "n_head, n_query_groups, enable_lora", [ (4, 2, (True, False, True)), # GQA: Q+V only (4, 1, (False, True, True)), # MQA: K+V only (4, 2, (True, True, False)), # GQA: Q+K only (8, 2, (True, True, True)), # GQA: all enabled, different ratio (4, 4, (False, False, True)), # MHA: V only ], ) def test_lora_ind_correctness(n_head, n_query_groups, enable_lora): """Verify lora_ind correctly partitions Q, K, V regions using head_size-based sizes.""" n_embd = 16 config = Config( n_layer=1, n_head=n_head, n_embd=n_embd, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.0, lora_query=enable_lora[0], lora_key=enable_lora[1], lora_value=enable_lora[2], n_query_groups=n_query_groups, ) model = LoRAGPT(config) attn = model.transformer.h[0].attn.qkv head_size = n_embd // n_head q_size = head_size * n_head kv_size = head_size * n_query_groups expected_ind = [] if enable_lora[0]: expected_ind.extend(range(0, q_size)) if enable_lora[1]: expected_ind.extend(range(q_size, q_size + kv_size)) if enable_lora[2]: expected_ind.extend(range(q_size + kv_size, q_size + 2 * kv_size)) assert torch.equal(attn.lora_ind, torch.tensor(expected_ind)) # Verify zero_pad output dimension matches full QKV size total_qkv = q_size + 2 * kv_size lora_out_dim = sum(attn.qkv_shapes) x = torch.randn(1, 1, lora_out_dim) assert attn.zero_pad(x).shape[-1] == total_qkv def test_lora_filter(tmp_path): fabric = Fabric(devices=1) model = LoRAGPT.from_name("pythia-14m", n_layer=3, lora_r=1, lora_query=True, lora_value=True) 
save_path = tmp_path / "model.pth" fabric.save(save_path, {"model": model}, filter={"model": lora_filter}) saved = torch.load(save_path)["model"] expected = { "transformer.h.1.attn.qkv.lora_B", "transformer.h.2.attn.qkv.lora_B", "transformer.h.2.attn.qkv.lora_A", "transformer.h.1.attn.qkv.lora_A", "transformer.h.0.attn.qkv.lora_A", "transformer.h.0.attn.qkv.lora_B", } assert set(saved) == expected @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) monkeypatch.setattr(module, "load_checkpoint", Mock()) monkeypatch.setattr(module, "merge_lora", Mock()) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) out_dir = tmp_path / "out" stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), out_dir=out_dir, precision="32-true", train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), ) out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == { "lit_model.pth.lora", "model_config.yaml", "tokenizer_config.json", "tokenizer.json", "hyperparameters.yaml", "prompt_style.yaml", } assert (out_dir / "logs" / "csv" / "version_0" 
/ "metrics.csv").is_file() logs = stdout.getvalue() assert logs.count("(step)") == 6 assert logs.count("val loss") == 4 # 3 validations + 1 final validation assert logs.count("Final evaluation") == 1 assert "of trainable parameters: 512" in logs def test_lora_init_when_linear_overridden(): class MyLinear(torch.nn.Linear): def __init__(self, *args, **kwargs): # this needs to be implemented to demonstrate the failure super().__init__(*args, **kwargs) original_linear = torch.nn.Linear # Our bnb does this sort of monkey patching torch.nn.Linear = MyLinear layer = LoRAQKVLinear(1, 1, 1, 1, 1) assert isinstance(layer.linear, original_linear) torch.nn.Linear = original_linear @pytest.mark.parametrize( ("apply_to", "target_layer_names", "mlp_class_name"), ( ("lora_projection", "transformer.h.0.attn.proj", "GptNeoxMLP"), ("lora_mlp", {"transformer.h.0.mlp.fc", "transformer.h.0.mlp.proj"}, "GptNeoxMLP"), ("lora_head", "lm_head", "GptNeoxMLP"), ("lora_projection", "transformer.h.0.attn.proj", "LLaMAMLP"), ("lora_mlp", {"transformer.h.0.mlp.fc_1", "transformer.h.0.mlp.fc_2", "transformer.h.0.mlp.proj"}, "LLaMAMLP"), ("lora_head", "lm_head", "LLaMAMLP"), ), ) def test_lora_linear_utilization(apply_to, target_layer_names, mlp_class_name): config = Config( n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1, mlp_class_name=mlp_class_name, intermediate_size=8 * 3, **{apply_to: True}, ) model = LoRAGPT(config) state_dict = model.state_dict() if isinstance(target_layer_names, str): target_layer_names = {target_layer_names} lora_sublayers = (".lora_A", ".lora_B") # check that all the target layers have LoRA weights for layer_name in target_layer_names: for lora_sublayer in lora_sublayers: assert layer_name + lora_sublayer in state_dict # check that only target layers have LoRA weights lora_params = [k for k in state_dict if k.endswith(lora_sublayers)] lora_params = {k[:-7] for k in lora_params} assert lora_params == target_layer_names 
@torch.inference_mode() @pytest.mark.parametrize( "apply_to", (None, "lora_query", "lora_key", "lora_value", "lora_projection", "lora_mlp", "lora_head") ) def test_lora_gpt_apply_lora_forward_no_exception(apply_to): config = Config(n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1) if apply_to: setattr(config, apply_to, True) input_ids = torch.tensor([[1]]) model = LoRAGPT(config) model.eval() model(input_ids) @torch.inference_mode() @pytest.mark.parametrize("n_query_groups", (1, 2, 3, 6)) @pytest.mark.parametrize("apply_to", product((False, True), repeat=3)) def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, apply_to): keys = ("lora_query", "lora_key", "lora_value") values = apply_to apply_to = dict(zip(keys, values)) config = Config( n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1, n_query_groups=n_query_groups, **apply_to, ) model = LoRAGPT(config) merge_lora_weights(model) input_ids = torch.tensor([[1]]) model(input_ids) @torch.inference_mode() @pytest.mark.parametrize("head_size", (1, 2, 4)) @pytest.mark.parametrize("n_head", (1, 2, 3, 6, 12)) @pytest.mark.parametrize( "enable_lora", [ (False, False, True), (False, True, False), (False, True, True), (True, False, False), (True, False, True), (True, True, False), (True, True, True), ], ) def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora): C = 12 layer = LoRAQKVLinear( C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora ) x = torch.randn((1, 1, C)) a = F.linear(x, layer.lora_A).transpose(-2, -1) # after_A b = layer.lora_B.data.unsqueeze(-1) # original PyTorch conv1d function output conv1d_pytorch = F.conv1d(a, b, groups=sum(layer.enable_lora)) # custom conv1d conv1d_custom = layer.conv1d(a, b) # custom conv1d forced to split, apply and concat tensors layer.n_head = layer.n_query_groups + 1 conv1d_custom_forced = 
layer.conv1d(a, b) assert torch.allclose(conv1d_pytorch, conv1d_custom) assert torch.allclose(conv1d_pytorch, conv1d_custom_forced) @pytest.mark.parametrize(("rank", "expected_merged"), ((0, False), (1, True))) def test_lora_linear_weights_merged_status(rank, expected_merged): layer = LoRALinear(10, 10, r=rank) assert not layer.merged layer.merge() assert layer.merged == expected_merged @pytest.mark.parametrize( ("rank", "enable_lora", "expected_merged"), ((0, True, False), (1, True, True), (0, False, False), (1, False, False)), ) def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged): C = 10 layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora) assert not layer.merged layer.merge() assert layer.merged == expected_merged @_RunIf(min_cuda_gpus=1) def test_lora_merge_with_bitsandbytes(): if not _BITSANDBYTES_AVAILABLE: pytest.skip("BNB not available") import bitsandbytes as bnb config = Config( n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=True, lora_projection=True, ) fabric = Fabric(devices=1, plugins=BitsandbytesPrecision("nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"})) model = LoRAGPT(config) mark_only_lora_as_trainable(model) from bitsandbytes.optim import PagedAdamW optimizer = PagedAdamW(model.parameters(), lr=1.0) model, optimizer = fabric.setup(model, optimizer) model.train() attn_proj = model.transformer.h[0].attn.proj initial_weight = attn_proj.linear.weight.clone() initial_weight_kwargs = attn_proj.linear.weight.__dict__ # this was skipped assert model.lm_head.linear.weight.dtype is torch.float32 assert attn_proj.linear.weight.dtype is torch.uint8 # perform an update to the LoRA weights y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device)) loss = y.sum() fabric.backward(loss) optimizer.step() optimizer.zero_grad() # the weight remains unchanged (only lora 
A and B change) assert torch.equal(attn_proj.linear.weight, initial_weight) # calling merge() multiple times in a row should not merge multiple times merge_lora_weights(model) assert attn_proj.merged weight_after = attn_proj.linear.weight.clone() merge_lora_weights(model) merge_lora_weights(model) assert torch.equal(attn_proj.linear.weight, weight_after) # check that `W_after = W_initial + (A x B)` delta_w = attn_proj.get_lora_AB() # dequantize initial weight and sum with delta_w initial_weight_data = ( bnb.functional.dequantize_4bit(initial_weight.data, initial_weight_kwargs["quant_state"]) + delta_w ) # quantize again initial_weight_data = bnb.nn.Params4bit( initial_weight_data.to("cpu"), requires_grad=False, **initial_weight_kwargs ).to(initial_weight.device) torch.testing.assert_close(weight_after, initial_weight_data) def test_lora_gpt_init_weights(): config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_head=True) model = LoRAGPT(config) param = model.lm_head.lora_B.data assert (param == 0).all() torch.nn.init.constant_(param, 1.23) assert (param != 0).any() model.apply(model._init_weights) assert (param == 0).all() @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_lora_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() lora_model = LoRAGPT.from_name( name, **kwargs, lora_r=1, lora_query=True, lora_key=True, lora_value=True, lora_projection=True, lora_mlp=True, lora_head=True, ) keys = lora_model.load_state_dict(base_model_state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert lora_filter(k, None) @_RunIf(dynamo=True) @torch.inference_mode() def test_lora_compile(): model = LoRAGPT.from_name( "pythia-14m", n_layer=3, lora_r=8, lora_alpha=8, lora_dropout=0.1, 
lora_query=True, lora_key=True, lora_value=True, lora_projection=True, lora_mlp=True, lora_head=True, ) x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64) explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 model = LoRAGPT(model.config) model.set_kv_cache(2) input_pos = torch.arange(model.config.block_size) explanation = torch._dynamo.explain(model)(x, input_pos) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 @torch.inference_mode() def test_against_hf_mixtral(): device = torch.device("cpu") dtype = torch.float32 ours_config = Config.from_name( "Mixtral-8x7B-Instruct-v0.1", padded_vocab_size=10000, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, n_expert=4, lora_r=1, lora_key=True, lora_value=True, ) T = 5 theirs_config = MixtralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, num_local_experts=ours_config.n_expert, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = MixtralForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert lora_filter(k, None) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, 
device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) def test_against_hf_gemma(model_name): device = torch.device("cpu") dtype = torch.float32 T = 5 ours_config = Config.from_name( model_name, n_layer=2, n_head=16, n_embd=32, head_size=4, intermediate_size=86, lora_r=1, lora_query=True, lora_key=True, lora_value=True, ) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys for k in keys.missing_keys: assert lora_filter(k, None) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b")) def test_against_original_gemma_2(model_name): 
device = torch.device("cpu") dtype = torch.float32 T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Gemma2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5) @torch.inference_mode() @pytest.mark.flaky(reruns=3) @pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it")) def 
test_against_original_gemma_3(model_name): device = torch.device("cpu") dtype = torch.float32 T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Gemma3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_3({}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @_RunIf(min_cuda_gpus=1) def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): if not _BITSANDBYTES_AVAILABLE: 
pytest.skip("BNB not available") from bitsandbytes.optim import PagedAdamW model_config = dict( block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, bias=True, lora_r=8, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=True, lora_projection=True, ) (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) tokenizer_mock = Mock() tokenizer_mock.return_value = tokenizer_mock tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) monkeypatch.setattr(module, "load_checkpoint", Mock()) monkeypatch.setattr(module, "merge_lora", Mock()) train_mock = Mock() train_mock.return_value = { "raw_tokens": 1000, "raw_tokens_plus_prompt_template": 1100, "raw_tokens_plus_prompt_template_and_padding": 1200, } monkeypatch.setattr(module, "fit", train_mock) stdout = StringIO() with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]): module.setup( fake_checkpoint_dir, data=Alpaca( download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 ), out_dir=tmp_path, precision="16-true", quantize="bnb.nf4-dq", ) _, kwargs = train_mock.call_args fabric = kwargs["fabric"] model = kwargs["model"] optimizer = kwargs["optimizer"] model.transformer.wte = model.transformer.wte.half() assert isinstance(fabric.strategy.precision, BitsandbytesPrecision) assert isinstance(optimizer, _FabricOptimizer) assert isinstance(optimizer._optimizer, PagedAdamW) dtype_to_name = {"torch.uint8": set(), "torch.float16": set()} for name, layer in model.named_parameters(): name = name[len("_forward_module.") :] dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.uint8": { "transformer.h.0.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", 
"transformer.h.1.attn.qkv.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { "transformer.h.0.attn.qkv.lora_B", "transformer.h.0.norm_2.weight", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.ln_f.bias", "transformer.h.1.attn.qkv.lora_B", "transformer.h.1.attn.proj.linear.bias", "transformer.h.1.norm_1.weight", "transformer.h.1.attn.qkv.linear.bias", "transformer.h.1.attn.qkv.lora_A", "transformer.h.1.norm_1.bias", "transformer.h.1.norm_2.bias", "transformer.h.0.attn.proj.linear.bias", "transformer.h.0.norm_1.bias", "transformer.h.0.mlp.proj.linear.bias", "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.norm_2.bias", "transformer.ln_f.weight", "transformer.h.0.attn.qkv.lora_A", "transformer.h.1.norm_2.weight", "transformer.h.1.mlp.proj.linear.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.attn.qkv.linear.bias", }, } assert {p.name for p in tmp_path.rglob("*.lora")} == {"lit_model.pth.lora"} state_dict = torch.load(tmp_path / "final" / "lit_model.pth.lora") assert len(state_dict) == 1 dtype_to_name = {"torch.float16": set()} for name, layer in state_dict["model"].items(): dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { "transformer.h.1.attn.qkv.lora_A", "transformer.h.0.attn.qkv.lora_A", "transformer.h.0.attn.qkv.lora_B", "transformer.h.1.attn.qkv.lora_B", } } logs = stdout.getvalue() assert "of trainable parameters: 512" in logs assert "of non-trainable parameters: 1,888" in logs @_RunIf(standalone=True, min_cuda_gpus=2) def test_lora_model_fsdp_init(): config = Config( n_layer=1, n_head=2, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=False, lora_projection=True, ) fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") fabric.launch() with 
fabric.init_module(empty_init=True): model = LoRAGPT(config) x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device) model = fabric.setup(model) y = model(x) assert y.shape == torch.Size([2, 8, 512]) # verify that all the parameters, buffers and other attributes aren't on `meta` device for m in model.modules(): for p_name, parameter in m.named_parameters(): assert not parameter.is_meta, f"Parameter `{p_name}` isn't materialized." for b_name, buffer in m._buffers.items(): assert not buffer.is_meta, f"Buffer `{b_name}` isn't materialized." for attr_name, attr_value in m.__dict__.items(): if isinstance(attr_value, torch.Tensor): assert not attr_value.is_meta, f"Attribute `{attr_name}` isn't materialized." def test_zero_pad_cpu_and_mocked_mps(): head_size = 64 n_head = 12 n_query_groups = 3 in_features = 128 q_size = head_size * n_head kv_size = head_size * n_query_groups out_features = q_size + 2 * kv_size enable_lora = [True, False, True] r = 4 model = LoRAQKVLinear( in_features=in_features, out_features=out_features, head_size=head_size, n_head=n_head, n_query_groups=n_query_groups, r=r, enable_lora=enable_lora, ) batch_size = 64 seq_len = 64 # embed_dim = sum of enabled qkv shapes: Q (q_size) + V (kv_size) embed_dim = q_size + kv_size x = torch.randn(batch_size, seq_len, embed_dim) result_cpu = model.zero_pad(x) with mock.patch("torch.backends.mps.is_available", return_value=True): with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device: mock_device.return_value = torch.device("mps") result_mps = model.zero_pad(x) assert result_cpu.shape == result_mps.shape, "Shape mismatch between CPU and MPS" assert torch.allclose(result_cpu, result_mps), "Tensor values mismatch between CPU and MPS" def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( 
n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, lora_r=8, lora_alpha=16, lora_dropout=0.1 ) attention_1 = CausalSelfAttention(config=config, block_idx=0) # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` # and make them interleaved state_dict = deepcopy(attention_1.state_dict()) state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) @_RunIf(standalone=True, min_cuda_gpus=2) def test_parallelize_fn(): from litgpt.finetune.lora import parallelize_fn config = Config( n_layer=2, n_head=4, n_embd=32, block_size=8, vocab_size=8, lora_r=4, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=True, lora_projection=True, ) fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") fabric.launch() model = LoRAGPT(config) mark_only_lora_as_trainable(model) # create device mesh for data parallel device_mesh = init_device_mesh( device_type=fabric.device.type, mesh_shape=(2, 1), mesh_dim_names=("data_parallel", "tensor_parallel"), ) # test with activation checkpointing enabled (default) parallelized_model = parallelize_fn(model, device_mesh, activation_checkpointing=True) # verify the model is still functional assert parallelized_model is not None assert isinstance(parallelized_model, LoRAGPT) parallelized_model = parallelized_model.to(fabric.device) # test forward pass to ensure the parallelized model works x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device) # verify forward pass works with torch.no_grad(): output = parallelized_model(x) assert output.shape == (1, config.block_size, config.padded_vocab_size) # test with activation checkpointing disabled model_no_checkpoint = LoRAGPT(config) 
mark_only_lora_as_trainable(model_no_checkpoint) parallelized_model_no_checkpoint = parallelize_fn(model_no_checkpoint, device_mesh, activation_checkpointing=False) # verify the model is still functional assert parallelized_model_no_checkpoint is not None assert isinstance(parallelized_model_no_checkpoint, LoRAGPT) # test forward pass to ensure the parallelized model works parallelized_model_no_checkpoint = parallelized_model_no_checkpoint.to(fabric.device) with torch.no_grad(): output = parallelized_model_no_checkpoint(x) assert output.shape == (1, config.block_size, config.padded_vocab_size) # verify that all parameters are properly distributed (not on meta device) for mod in parallelized_model.modules(): for param_name, param in mod.named_parameters(): if param.requires_grad: # Only check trainable parameters (LoRA parameters) assert not param.is_meta, f"Parameter `{param_name}` should not be on meta device" assert param.device.type == "cuda", f"Parameter `{param_name}` should be on CUDA device" @_RunIf(standalone=True, min_cuda_gpus=2) def test_load_from_full_model_state_dict(): from litgpt.finetune.lora import parallelize_fn from litgpt.utils import load_from_full_model_state_dict config = Config( n_layer=2, n_head=4, n_embd=32, block_size=8, vocab_size=8, lora_r=4, lora_alpha=8, lora_dropout=0.1, lora_query=True, lora_value=True, lora_projection=True, lora_mlp=True, lora_head=True, ) # set up distributed environment with FSDP fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") fabric.launch() # create a reference model to get the full state dict reference_model = LoRAGPT(config) mark_only_lora_as_trainable(reference_model) # initialize the reference model with some values with torch.no_grad(): for param in reference_model.parameters(): if param.requires_grad: param.fill_(0.1) # get the full state dict (simulating a checkpoint) full_state_dict = {} for name, param in reference_model.named_parameters(): # Convert parameters to checkpoint format 
(what load_from_full_model_state_dict expects) if "norm" not in name and "wte" not in name and "ln_f" not in name: # For linear layers, remove .linear from the name to simulate checkpoint format checkpoint_name = name.replace(".linear.weight", ".weight").replace(".linear.bias", ".bias") else: # For norm, embedding, and layer norm layers, keep the original name checkpoint_name = name full_state_dict[checkpoint_name] = param.detach().clone() # create distributed model model = LoRAGPT(config) mark_only_lora_as_trainable(model) # set up device mesh for distributed model device_mesh = init_device_mesh( device_type=fabric.device.type, mesh_shape=(2, 1), mesh_dim_names=("data_parallel", "tensor_parallel"), ) model = parallelize_fn(model, device_mesh, activation_checkpointing=False) model = model.to(fabric.device) # test with default parameters (strict=False, cpu_offload=False) result = load_from_full_model_state_dict( model=model, full_sd=full_state_dict, device=fabric.device, strict=False, cpu_offload=False, ) # verify that the function returns the missing/unexpected keys assert hasattr(result, "missing_keys") assert hasattr(result, "unexpected_keys") # verify that parameters are loaded correctly for name, param in model.named_parameters(): if param.requires_grad: # Check that parameter is not on meta device assert not param.is_meta, f"Parameter {name} should not be on meta device" # Check that parameter is on the correct device assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device" # test with cpu_offload=True model_cpu_offload = LoRAGPT(config) mark_only_lora_as_trainable(model_cpu_offload) model_cpu_offload = parallelize_fn(model_cpu_offload, device_mesh, activation_checkpointing=False) model_cpu_offload = model_cpu_offload.to(fabric.device) result_cpu_offload = load_from_full_model_state_dict( model=model_cpu_offload, full_sd=full_state_dict, device=fabric.device, strict=False, cpu_offload=True, ) # verify that parameters are loaded 
correctly with CPU offload for name, param in model_cpu_offload.named_parameters(): if param.requires_grad: # Check that parameter is not on meta device assert not param.is_meta, f"Parameter {name} should not be on meta device" # With cpu_offload, parameters might be on CPU assert param.device.type in ["cpu", "cuda"], f"Parameter {name} should be on CPU or CUDA device" # test with strict=True model_strict = LoRAGPT(config) mark_only_lora_as_trainable(model_strict) model_strict = parallelize_fn(model_strict, device_mesh, activation_checkpointing=False) model_strict = model_strict.to(fabric.device) try: result_strict = load_from_full_model_state_dict( model=model_strict, full_sd=full_state_dict, device=fabric.device, strict=True, cpu_offload=False, ) # If strict loading succeeds, verify parameters for name, param in model_strict.named_parameters(): if param.requires_grad: assert not param.is_meta, f"Parameter {name} should not be on meta device" assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device" except RuntimeError as e: # strict=True might fail if there are missing keys, which is expected behavior assert "Missing key(s)" in str(e) or "Unexpected key(s)" in str(e) # test forward pass to ensure model still works after loading x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device) with torch.no_grad(): output = model(x) assert output.shape == (1, config.block_size, config.padded_vocab_size) output_cpu_offload = model_cpu_offload(x) assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size) ================================================ FILE: tests/test_merge_lora.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os import shutil from contextlib import redirect_stdout from io import StringIO from pathlib import Path from unittest import mock import pytest import torch import yaml from litgpt.lora import GPT as LoRAGPT from litgpt.lora import lora_filter from litgpt.model import GPT from litgpt.scripts.merge_lora import load_lora_metadata, merge_lora @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @pytest.mark.parametrize( ("pretrained_dtype", "lora_dtype"), [(None, None), (torch.float16, torch.float32), (torch.float16, torch.bfloat16)] ) def test_merge_lora(tmp_path, fake_checkpoint_dir, pretrained_dtype, lora_dtype): pretrained_checkpoint_dir = tmp_path / "pretrained" lora_checkpoint_dir = tmp_path / "lora" shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir) shutil.copytree(fake_checkpoint_dir, lora_checkpoint_dir) (lora_checkpoint_dir / "lit_model.pth").unlink() # should not already exist shutil.rmtree(tmp_path / "checkpoints") # Create a fake pretrained checkpoint config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16) with open(pretrained_checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp: yaml.dump(config, fp) base_model = GPT.from_name("pythia-14m", **config).to(dtype=pretrained_dtype) state_dict = base_model.state_dict() assert len(state_dict) == 40 torch.save(state_dict, pretrained_checkpoint_dir / "lit_model.pth") # Create a fake LoRA checkpoint lora_kwargs = dict(lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_query=True, lora_value=True) lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs).to(dtype=lora_dtype) state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)} assert len(state_dict) == 6 torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth.lora") hparams = dict(checkpoint_dir=str(pretrained_checkpoint_dir), **lora_kwargs) with open(lora_checkpoint_dir / "hyperparameters.yaml", "w", encoding="utf-8") as file: yaml.dump(hparams, file) 
shutil.copyfile(pretrained_checkpoint_dir / "model_config.yaml", lora_checkpoint_dir / "model_config.yaml") assert set(os.listdir(tmp_path)) == {"lora", "pretrained"} merge_lora(lora_checkpoint_dir) assert set(os.listdir(tmp_path)) == {"lora", "pretrained"} assert set(os.listdir(lora_checkpoint_dir)) == { "model_config.yaml", "lit_model.pth", "lit_model.pth.lora", "tokenizer.json", "tokenizer_config.json", "hyperparameters.yaml", } # Assert that the merged weights can be loaded back into the base model merged = torch.load(lora_checkpoint_dir / "lit_model.pth") keys = base_model.load_state_dict(merged, strict=True) assert not keys.missing_keys assert not keys.unexpected_keys # Attempt to merge again stdout = StringIO() with redirect_stdout(stdout): merge_lora(lora_checkpoint_dir) assert "LoRA weights have already been merged" in stdout.getvalue() def test_load_lora_metadata(fake_checkpoint_dir): assert not (fake_checkpoint_dir / "hyperparameters.yaml").is_file() with pytest.raises(FileNotFoundError, match="missing a `hyperparameters.yaml` file"): load_lora_metadata(fake_checkpoint_dir) hparams = dict(precision="bf16-mixed", checkpoint_dir="checkpoints/meta-llama/Llama-2-7b", lora_r=8, lora_alpha=16) with open(fake_checkpoint_dir / "hyperparameters.yaml", "w", encoding="utf-8") as file: yaml.dump(hparams, file) lora_args, pretrained_dir, precision = load_lora_metadata(fake_checkpoint_dir) assert lora_args == dict(lora_r=8, lora_alpha=16) assert pretrained_dir == Path("checkpoints/meta-llama/Llama-2-7b") assert precision == "bf16-mixed" ================================================ FILE: tests/test_model.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
from copy import deepcopy from functools import partial from unittest import mock import pytest import torch from lightning import Fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.init import _materialize_meta_tensors from torch._dynamo.backends import debugging from torch.backends.cuda import ( SDPAParams, SDPBackend, can_use_efficient_attention, can_use_flash_attention, flash_sdp_enabled, math_sdp_enabled, mem_efficient_sdp_enabled, ) from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.falcon import FalconConfig, FalconForCausalLM from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM from transformers.models.gemma3 import Gemma3Config, Gemma3ForCausalLM, Gemma3ForConditionalGeneration, Gemma3TextConfig from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM from transformers.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM import litgpt.config as config_module from litgpt import GPT, Config from litgpt.model import CausalSelfAttention, batched_index_copy_ from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_gpt_neox, copy_weights_hf_llama, copy_weights_olmo2, copy_weights_phi, copy_weights_qwen_2_5, copy_weights_qwen_3, ) from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from 
litgpt.utils import _RunIf @torch.inference_mode() @pytest.mark.parametrize("rotary_pct", (0.25, 1)) @pytest.mark.parametrize("batch_size", (1, 3)) @pytest.mark.parametrize("n_embd", (16, 32)) @pytest.mark.parametrize("parallel_residual", (False, True)) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None: torch.set_default_dtype(dtype) ours_config = Config( block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=n_embd, rotary_percentage=rotary_pct, parallel_residual=parallel_residual, ) assert ours_config.padded_vocab_size == 512 theirs_config = GPTNeoXConfig( hidden_act="gelu", hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, initializer_range=0.02, intermediate_size=ours_config.intermediate_size, layer_norm_eps=ours_config.norm_eps, max_position_embeddings=ours_config.block_size, rotary_emb_base=10000, rotary_pct=ours_config.rotary_percentage, vocab_size=ours_config.padded_vocab_size, use_parallel_residual=ours_config.parallel_residual, attn_implementation="eager", ) state_dict = {} theirs_model = GPTNeoXForCausalLM(theirs_config).to(device) # load the hf initialization into our model copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict()) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) token_sample = torch.randint( 0, ours_config.padded_vocab_size, size=(batch_size, ours_config.block_size), dtype=torch.int64, device=device ) theirs = theirs_model(token_sample)["logits"] ours = ours_model(token_sample) torch.testing.assert_close(ours, 
theirs, rtol=1e-2, atol=1e-2) @torch.inference_mode() @pytest.mark.parametrize( "kwargs", [ dict(name="falcon-180B", n_layer=2, n_head=8, n_query_groups=4, n_embd=32), dict(name="falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32), ], ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_hf_falcon(kwargs, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name(**kwargs) theirs_config = FalconConfig( hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_kv_heads=ours_config.n_query_groups, num_hidden_layers=ours_config.n_layer, parallel_attn=ours_config.parallel_residual, vocab_size=ours_config.padded_vocab_size, bias=ours_config.bias, new_decoder_architecture=True, ) theirs_model = FalconForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_falcon(ours_config, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_open_llama_3b(device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86) T = 5 theirs_config = LlamaConfig( hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = LlamaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( "ours_kwargs", [ {"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}, {"name": "Llama-3-8B"}, {"name": "Llama-3-8B-Instruct"}, {"name": "Llama-3.1-405B", "n_query_groups": 4}, {"name": "Llama-3.1-8B"}, {"name": "Llama-3.1-8B-Instruct"}, {"name": "Llama-3.2-1B"}, {"name": "Llama-3.2-3B"}, {"name": "Llama-3.3-70B-Instruct"}, {"name": "R1-Distill-Llama-8B"}, {"name": "R1-Distill-Llama-70B"}, ], ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs ) T = 5 theirs_config = LlamaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = LlamaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)], ), ], ) def test_against_hf_phi(model_name, device, dtype): from transformers.models.phi.configuration_phi import PhiConfig from transformers.models.phi.modeling_phi import PhiForCausalLM torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, 
padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5 ) T = 5 theirs_config = PhiConfig( vocab_size=ours_config.padded_vocab_size, max_position_embeddings=ours_config.block_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, partial_rotary_factor=ours_config.rotary_percentage, torch_dtype=dtype, ) theirs_model = PhiForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( "model_name", ( "Phi-3-mini-4k-instruct", "Phi-3-mini-128k-instruct", "Phi-3.5-mini-instruct", "phi-4", "Phi-4-mini-instruct", "Phi-4-reasoning", "Phi-4-mini-reasoning", ), ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)], ), ], ) def test_against_hf_phi_3(model_name, device, dtype): from transformers.models.phi3.configuration_phi3 import Phi3Config from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_query_groups=4, n_embd=256, ) T = 5 theirs_config = Phi3Config( attention_bias=ours_config.bias, head_dim=ours_config.head_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, 
num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, num_key_value_heads=ours_config.n_query_groups, pad_token_id=ours_config.padded_vocab_size - 1, partial_rotary_factor=ours_config.rotary_percentage, rms_norm_eps=ours_config.norm_eps, rope_theta=ours_config.rope_base, torch_dtype=dtype, vocab_size=ours_config.padded_vocab_size, ) theirs_model = Phi3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) @pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mistral-7B-v0.1"]) def test_against_mistral_hf_models(device, dtype, model_name): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, padded_vocab_size=10000, block_size=T, sliding_window_size=T // 2, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, ) theirs_config = MistralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attn_implementation="eager", sliding_window=ours_config.sliding_window_size, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = MistralForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_mathstral_hf_models(device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( "Mathstral-7B-v0.1", padded_vocab_size=10000, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, ) T = 5 theirs_config = MistralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = MistralForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1")) def test_against_hf_mixtral(model_name): device = torch.device("cpu") dtype = torch.float32 ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_embd=32, n_head=8, n_query_groups=2, intermediate_size=86, n_expert=4, ) T = 5 theirs_config = MixtralConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, 
max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, num_local_experts=ours_config.n_expert, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = MixtralForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_olmo(model_name, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, ) T = 5 theirs_config = OlmoConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, num_hidden_layers=ours_config.n_layer, num_attention_heads=ours_config.n_head, num_key_value_heads=ours_config.n_query_groups, max_positional_embeddings=T, attention_bias=ours_config.bias, rope_theta=ours_config.rope_base, tie_word_embeddings=(model_name == "OLMo-1B-hf"), ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = OlmoForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_olmo2(model_name, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, n_query_groups=2, intermediate_size=86, ) T = 5 theirs_config = Olmo2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, num_hidden_layers=ours_config.n_layer, num_attention_heads=ours_config.n_head, num_key_value_heads=ours_config.n_query_groups, max_positional_embeddings=T, rms_norm_eps=ours_config.norm_eps, attention_bias=ours_config.bias, rope_theta=ours_config.rope_base, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = Olmo2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_stablelm_zephyr_3b(device, dtype): torch.set_default_dtype(dtype) T = 5 ours_config = Config.from_name("stablelm-zephyr-3b", n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = AutoConfig.from_pretrained( "stabilityai/stablelm-zephyr-3b", trust_remote_code=True, num_hidden_layers=ours_config.n_layer, num_attention_heads=ours_config.n_head, num_key_value_heads=ours_config.n_head, hidden_size=ours_config.n_embd, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, torch_dtype=dtype, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma(model_name, device, dtype): torch.set_default_dtype(dtype) T = 5 ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = GemmaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_2(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_logit_softcapping=ours_config.attention_logit_softcapping, final_logit_softcapping=ours_config.final_logit_softcapping, initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, ) theirs_model = Gemma2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @torch.inference_mode() @pytest.mark.parametrize("model_name", 
["gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"]) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_gemma_3(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, rope_scaling={"factor": 8.0, "rope_type": "linear"}, rope_local_base_freq=ours_config.rope_local_base_freq, ) theirs_model = Gemma3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_gemma_3({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = 
theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @torch.inference_mode() @pytest.mark.parametrize("model_name", ["gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"]) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_multimodal_gemma_3(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, sliding_window_size=T // 2, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Gemma3Config( Gemma3TextConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, sliding_window=ours_config.sliding_window_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, tie_word_embeddings=True, hidden_act="gelu_pytorch_tanh", attn_implementation="eager", query_pre_attn_scalar=ours_config.attention_scores_scalar, rope_scaling={"factor": 8.0, "rope_type": "linear"}, rope_local_base_freq=ours_config.rope_local_base_freq, ) ) theirs_model = Gemma3ForConditionalGeneration(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_gemma_3({}, state_dict, theirs_state_dict, config=ours_config) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, 
high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @torch.inference_mode() @pytest.mark.parametrize( "model_name", ["Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B"] ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_qwen_2_5(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Qwen2Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.attn_bias, tie_word_embeddings=True, ) theirs_model = Qwen2ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} copy_weights_qwen_2_5(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T 
ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( "model_name", [ "Qwen3-0.6B", "Qwen3-8B", "Qwen3-4B-Base", "Qwen3-14B-Base", "Qwen3-32B", "Qwen3-4B-Thinking-2507", "Qwen3-4B-Instruct-2507", ], ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_qwen_3(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, ) theirs_config = Qwen3Config( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=ours_config.block_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, tie_word_embeddings=False, ) theirs_model = Qwen3ForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize( "model_name", 
["Qwen3-30B-A3B", "Qwen3-235B-A22B", "Qwen3-235B-A22B-Thinking-2507", "Qwen3-235B-A22B-Instruct-2507"] ) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_qwen_3_moe(model_name, device, dtype): torch.set_default_dtype(dtype) T = 20 ours_config = Config.from_name( model_name, block_size=T, n_layer=2, n_head=16, n_embd=32, intermediate_size=86, moe_intermediate_size=20, n_expert=4, n_expert_per_token=2, ) theirs_config = Qwen3MoeConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, head_dim=ours_config.head_size, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, moe_intermediate_size=ours_config.moe_intermediate_size, max_position_embeddings=ours_config.block_size, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, tie_word_embeddings=False, num_experts=ours_config.n_expert, num_experts_per_tok=ours_config.n_expert_per_token, norm_topk_prob=True, ) theirs_model = Qwen3MoeForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", 
("salamandra-2b", "salamandra-7b")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_salamandra(model_name, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, n_query_groups=2, intermediate_size=86, ) T = 5 theirs_config = LlamaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = LlamaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_original_smollm2(model_name, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, n_query_groups=2, intermediate_size=86, ) T = 5 theirs_config = LlamaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = LlamaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @torch.inference_mode() @pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base")) @pytest.mark.parametrize( ("device", "dtype"), [ (torch.device("cpu"), torch.float32), pytest.param( torch.device("cuda"), torch.float16, marks=[ # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input # is slightly different pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1), ], ), ], ) def test_against_hf_falcon3(model_name, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, n_query_groups=2, intermediate_size=86, ) T = 5 theirs_config = LlamaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, num_attention_heads=ours_config.n_head, num_hidden_layers=ours_config.n_layer, intermediate_size=ours_config.intermediate_size, max_position_embeddings=T, rms_norm_eps=ours_config.norm_eps, num_key_value_heads=ours_config.n_query_groups, rope_theta=ours_config.rope_base, attention_bias=ours_config.bias, ) assert ours_config.intermediate_size == theirs_config.intermediate_size theirs_model = LlamaForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) @_RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): model = GPT.from_name("pythia-14m", n_layer=3) x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64) explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 model = GPT(model.config) model.set_kv_cache(2) input_pos = torch.arange(model.config.block_size) explanation = torch._dynamo.explain(model)(x, input_pos) assert isinstance(explanation, 
debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 @torch.inference_mode() @pytest.mark.parametrize( "max_seq_length", (25, pytest.param(23, marks=pytest.mark.xfail(raises=IndexError, strict=True))) ) @pytest.mark.flaky(reruns=5) def test_kv_cache(max_seq_length): config = Config(block_size=25, padded_vocab_size=5, n_layer=2, n_head=2, n_embd=8) model = GPT(config) idx = torch.randint(0, model.config.padded_vocab_size, (1, 5)) max_new_tokens = 20 model.max_seq_length = max_seq_length model.set_kv_cache(1) def generate(logits): logits = logits[:, -1:] probs = torch.nn.functional.softmax(logits, dim=-1) return torch.argmax(probs).unsqueeze(0).unsqueeze(0) x_no_cache = idx x_cache = idx input_pos = torch.arange(0, 5) for _ in range(max_new_tokens): logits_no_cache = model(x_no_cache[:, -max_seq_length:]) out_no_cache = generate(logits_no_cache) logits_cache = model(x_cache, input_pos) out_cache = generate(logits_cache) torch.testing.assert_close(out_no_cache, out_cache, rtol=0, atol=0) x_no_cache = torch.cat((x_no_cache, out_no_cache), dim=1) x_cache = out_cache input_pos = input_pos[-1:] + 1 @torch.inference_mode() def test_model_kv_cache_amp(): config = Config.from_name("pythia-14m", n_layer=2) model = GPT(config) encoded = torch.arange(45) model.set_kv_cache(batch_size=1) with torch.autocast("cpu", torch.bfloat16): output = model(encoded.unsqueeze(0), encoded) assert output.dtype is torch.bfloat16 @pytest.mark.parametrize("model_name", ["pythia-14m", "gemma-3-1b-it"]) def test_rope_cache_length(model_name): config = Config.from_name(model_name, n_layer=2) model = GPT(config) model.max_seq_length = 128 rope_len = model.rope_cache_length() assert rope_len == config.rope_n_elem # Verify it works with set_kv_cache model.set_kv_cache(batch_size=1) assert model.transformer.h[0].attn.kv_cache is not None # https://github.com/pytorch/pytorch/blob/ad3572a5d/torch/testing/_internal/common_cuda.py#L31-L34 
# True only when CUDA is available, the current GPU reports compute capability >= (8, 0),
# and the platform is not Windows.
SUPPORTS_FLASH_ATTENTION = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (8, 0)
    and not _IS_WINDOWS
)


@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(
    "config", deepcopy(config_module.configs), ids=[entry["name"] for entry in config_module.configs]
)
@torch.inference_mode()
def test_sdpa_choice(config):
    """Verify which scaled-dot-product-attention backend a one-layer model built from ``config`` uses.

    Each attention module's ``scaled_dot_product_attention`` is wrapped. Before delegating to the
    original function, the wrapper asserts that the currently expected backend is enabled. For the
    flash and memory-efficient backends, it also asserts the backend is usable for the incoming
    ``q``/``k``/``v``/``mask``.
    The model is then run once with memory-efficient attention disabled (expecting flash, only when
    ``SUPPORTS_FLASH_ATTENTION``) and once with flash disabled (expecting memory-efficient attention
    when ``head_size`` is a multiple of 8, otherwise math).
    """
    if config["name"].startswith("Gemma-2-"):
        pytest.skip("Gemma 2 doesn't support SDPA")

    torch.set_default_dtype(torch.float16)

    def _check_backend(sdpa_fn, q, k, v, mask):
        # SDPAParams gained an additional trailing argument (`enable_gqa`) in PyTorch 2.5
        extra = (False,) if hasattr(SDPAParams, "enable_gqa") else ()
        params = SDPAParams(q, k, v, mask, 0.0, True, *extra)

        # `expected_backend` and `model_config` are read from the enclosing test at call time
        if expected_backend is SDPBackend.MATH:
            assert math_sdp_enabled(), "math_sdp_enabled() is False"
        elif expected_backend is SDPBackend.EFFICIENT_ATTENTION:
            assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
            assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
        elif expected_backend is SDPBackend.FLASH_ATTENTION:
            assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
            # the flash usability check is only asserted for configs without a sliding window
            if model_config.sliding_window_size is None:
                assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
        else:
            raise NotImplementedError
        return sdpa_fn(q, k, v, mask)

    config["n_layer"] = 1
    model_config = config_module.Config(**config)

    try:
        with torch.device("cuda"):
            model = GPT(model_config)
            tokens = torch.randint(0, 10, (2, 16), dtype=torch.int32)
    except torch.cuda.OutOfMemoryError:
        # best effort, if the GPU can load it
        pytest.xfail()

    for block in model.transformer.h:
        wrapped = block.attn.scaled_dot_product_attention
        block.attn.scaled_dot_product_attention = partial(_check_backend, wrapped)

    if SUPPORTS_FLASH_ATTENTION:
        expected_backend = SDPBackend.FLASH_ATTENTION
        with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
            model(tokens)

    if model_config.head_size % 8 == 0:
        expected_backend = SDPBackend.EFFICIENT_ATTENTION
    else:
        expected_backend = SDPBackend.MATH
    with torch.backends.cuda.sdp_kernel(enable_flash=False):
        model(tokens)
@_RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("config", deepcopy(config_module.configs), ids=[c["name"] for c in config_module.configs]) @torch.inference_mode() def test_sdpa_choice_kv_cache(config): torch.set_default_dtype(torch.float16) def assert_sdpa_backend(original_fn, q, k, v, mask): # SDPAParams gained an additional argument in PyTorch 2.5 args = [] if hasattr(SDPAParams, "enable_gqa"): args.append(False) params = SDPAParams(q, k, v, mask, 0.0, True, *args) if expected is SDPBackend.FLASH_ATTENTION: assert flash_sdp_enabled() assert can_use_flash_attention(params, True) elif expected is SDPBackend.EFFICIENT_ATTENTION: assert mem_efficient_sdp_enabled() assert can_use_efficient_attention(params, True) elif expected is SDPBackend.MATH: assert math_sdp_enabled() else: raise NotImplementedError return original_fn(q, k, v, mask) config["n_layer"] = 1 config = config_module.Config(**config) try: with torch.device("cuda"): model = GPT(config) model.max_seq_length = 1 model.set_kv_cache(2) x = torch.randint(0, 10, (2, 1), dtype=torch.int32) input_pos = torch.tensor([0], dtype=torch.long) except torch.cuda.OutOfMemoryError: # best effort, if the GPU can load it pytest.xfail() for h in model.transformer.h: h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention) if SUPPORTS_FLASH_ATTENTION: # flash attention does not support an attention mask expected = SDPBackend.MATH with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False): model(x, input_pos) expected = ( SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 and config.n_query_groups != 1 else SDPBackend.MATH ) with torch.backends.cuda.sdp_kernel(enable_flash=False): model(x, input_pos) @_RunIf(min_cuda_gpus=2, standalone=True) def test_rope_init_under_fsdp(): """Check that the rope cache is properly initialized""" fabric = Fabric(devices=2, strategy="fsdp", accelerator="cuda") fabric.launch() with fabric.init_module(empty_init=True): model = 
GPT.from_name("pythia-14m", n_layer=1) assert model.cos.device.type == "meta" assert model.sin.device.type == "meta" model = fabric.setup(model) assert model.cos.device.type == "cuda" assert model.sin.device.type == "cuda" cos, sin = model.rope_cache(device=fabric.device) torch.testing.assert_close(model.cos, cos) torch.testing.assert_close(model.sin, sin) @_RunIf(min_cuda_gpus=1) def test_reset_parameters_device(): with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=1) _materialize_meta_tensors(model, torch.device("cuda")) model.reset_parameters() assert model.cos.device.type == "cuda" def test_batched_index_copy_modes(): # Mock the torch.backends.mps.is_available() function to simulate MPS availability with mock.patch("torch.backends.mps.is_available", return_value=True): # Mock the device type to simulate the "mps" device with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device: mock_device.return_value = torch.device("mps") # Test case when idx.dim() == 1 t_original_1 = torch.randn(3, 5) dim_1 = 0 idx_1 = torch.tensor([0, 2]) val_1 = torch.randn(2, 5) t1_cpu = t_original_1.clone() t1_mps = t_original_1.clone() # Perform the index copy on CPU batched_index_copy_(t1_cpu, dim_1, idx_1, val_1) # Simulate the MPS index copy idx_1_mps = idx_1 val_1_mps = val_1 batched_index_copy_(t1_mps, dim_1, idx_1_mps, val_1_mps) assert torch.allclose(t1_cpu, t1_mps), "Mismatch with idx.dim() == 1 on mocked MPS" # Test case when idx.dim() == 2 t_original_2 = torch.randn(2, 5, 4) dim_2 = 1 idx_2 = torch.tensor([[0, 2], [1, 3]]) val_2 = torch.randn(2, 2, 4) t2_cpu = t_original_2.clone() t2_mps = t_original_2.clone() # Perform the index copy on CPU batched_index_copy_(t2_cpu, dim_2, idx_2, val_2) # Simulate the MPS index copy idx_2_mps = idx_2 val_2_mps = val_2 batched_index_copy_(t2_mps, dim_2, idx_2_mps, val_2_mps) assert torch.allclose(t2_cpu, t2_mps), "Mismatch with idx.dim() == 2 on mocked MPS" # Additional test with negative 
dimension t_original_3 = torch.randn(2, 3, 4) dim_3 = -2 idx_3 = torch.tensor([[0, 1], [1, 2]]) val_3 = torch.randn(2, 2, 4) t3_cpu = t_original_3.clone() t3_mps = t_original_3.clone() # Perform the index copy on CPU batched_index_copy_(t3_cpu, dim_3, idx_3, val_3) # Simulate the MPS index copy idx_3_mps = idx_3 val_3_mps = val_3 batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, ) attention_1 = CausalSelfAttention(config=config, block_idx=0) # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` # and make them interleaved state_dict = deepcopy(attention_1.state_dict()) state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) @pytest.mark.parametrize("n_query_groups", (1, 2, 4, 8)) @torch.inference_mode() def test_kv_cache_buffer_shape(n_query_groups): batch_size = 3 max_seq_length = 23 config = Config( block_size=25, padded_vocab_size=5, n_layer=2, n_head=8, n_embd=16, n_query_groups=n_query_groups, ) model = GPT(config) model.max_seq_length = max_seq_length model.set_kv_cache(batch_size) required_shape = (batch_size, n_query_groups, max_seq_length, config.head_size) for block in model.transformer.h: kv_cache = block.attn.kv_cache assert kv_cache is not None assert kv_cache.k.shape == required_shape assert kv_cache.v.shape == required_shape @pytest.mark.parametrize(("rotary_percentage", "final_dim"), ((0.75, 3), (0.25, 2))) @torch.inference_mode() def 
test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim): batch_size = 3 config = Config( block_size=25, padded_vocab_size=5, n_layer=2, n_head=4, n_embd=16, rotary_percentage=rotary_percentage, ) model = GPT(config) required_shape = (config.block_size, final_dim) assert model.cos.shape == required_shape assert model.sin.shape == required_shape def test_forward_with_without_input_pos_maxp1(): batch_size = 3 config = Config( block_size=25, padded_vocab_size=5, n_layer=2, n_head=8, n_embd=16, ) model = GPT(config) model.set_kv_cache(batch_size) idx = torch.randint(0, config.padded_vocab_size, (1, 10)) input_pos = torch.arange(1, 11) input_pos_maxp1 = 11 logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1) logits_no_maxp1 = model(idx, input_pos) torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1) ================================================ FILE: tests/test_multihead_latent_attention.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import pytest import torch from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM from litgpt import Config from litgpt.model import MultiheadLatentAttention @torch.inference_mode() def test_multihead_latent_attention_kv_cache(): """Test KV cache functionality""" config = Config( block_size=32, n_embd=64, n_head=4, n_query_groups=4, head_size=16, latent_attention={ "q_lora_rank": 32, "kv_lora_rank": 16, "qk_rope_head_dim": 8, "qk_nope_head_dim": 8, "v_head_dim": 16, }, ) mla = MultiheadLatentAttention(config, block_idx=0) # Build KV cache kv_cache = mla.build_kv_cache(batch_size=2, max_seq_length=32, device=torch.device("cpu"), dtype=torch.float32) # Check cache shapes assert kv_cache.k.shape == (2, config.n_head, 32, config.qk_head_dim) assert kv_cache.v.shape == (2, config.n_head, 32, config.v_head_dim) @torch.inference_mode() def test_multihead_latent_attention_with_mask(): """Test attention with causal mask""" config = Config( n_embd=64, n_head=4, n_query_groups=4, head_size=16, latent_attention={ "q_lora_rank": 32, "kv_lora_rank": 16, "qk_rope_head_dim": 8, "qk_nope_head_dim": 8, "v_head_dim": 16, }, ) mla = MultiheadLatentAttention(config, block_idx=0) batch_size, seq_len = 1, 8 x = torch.randn(batch_size, seq_len, config.n_embd) cos = torch.randn(1, seq_len, config.qk_rope_head_dim) sin = torch.randn(1, seq_len, config.qk_rope_head_dim) # Create causal mask mask = torch.ones(seq_len, seq_len, dtype=x.dtype).triu(diagonal=1) mask.masked_fill_(mask.bool(), float("-inf")) mask = mask.view(1, 1, seq_len, seq_len) # Forward pass with mask output = mla(x, cos, sin, mask=mask) assert output.shape == (batch_size, seq_len, config.n_embd) @torch.inference_mode() @pytest.mark.parametrize("batch_size", (1, 2)) @pytest.mark.parametrize("seq_len", (8, 16)) @pytest.mark.parametrize("device", [torch.device("cpu")]) def test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device): """Test MLA litgpt vs hf""" config_litgpt = Config( 
n_embd=64, n_head=4, n_query_groups=4, head_size=16, norm_eps=1e-6, bias=False, latent_attention={ "q_lora_rank": 32, "kv_lora_rank": 16, "qk_rope_head_dim": 8, "qk_nope_head_dim": 8, "v_head_dim": 16, }, ) config_hf = DeepseekV3Config( padded_vocab_size=10000, num_hidden_layers=1, vocab_size=10000, hidden_size=64, num_attention_heads=4, num_key_value_heads=4, q_lora_rank=32, kv_lora_rank=16, qk_rope_head_dim=8, qk_nope_head_dim=8, v_head_dim=16, rope_interleave=False, ) mla_litgpt = MultiheadLatentAttention(config_litgpt, block_idx=0).to(device) model_hf = DeepseekV3ForCausalLM(config_hf).to(device) mla_hf = model_hf.model.layers[0].self_attn mla_litgpt.eval() mla_hf.eval() sync_weights(mla_litgpt, mla_hf) hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device) # Prepare RoPE sin/cos tables rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"] cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) causal_mask = torch.triu( torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1 ) attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) # Run forward passes output_litgpt = mla_litgpt(hidden_states, cos, sin) output_hf = mla_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)[0] assert torch.allclose(output_litgpt, output_hf, atol=1e-5) def sync_weights(litgpt_model, hf_model): """Copies weights from lit-gpt model to HF model.""" print("Synchronizing weights...") with torch.no_grad(): hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight) hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight) hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight) hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight) 
hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight) hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight) hf_model.o_proj.weight.copy_(litgpt_model.proj.weight) print("Synchronization complete.") ================================================ FILE: tests/test_pretrain.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout from io import StringIO from unittest import mock from unittest.mock import ANY, Mock import pytest import torch from lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy from torch.utils.data import DataLoader from litgpt import pretrain from litgpt.args import EvalArgs, TrainArgs from litgpt.config import Config from litgpt.pretrain import initialize_weights from litgpt.utils import _RunIf @_RunIf(min_cuda_gpus=1, standalone=True) @mock.patch("litgpt.pretrain.save_hyperparameters") def test_optimizer_args(_, tmp_path): model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8) dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]]) dataloader = DataLoader(dataset) pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader)) for i in ("AdamW", "SGD", "RMSprop"): pretrain.setup( "pythia-14m", devices=1, optimizer="RMSprop", model_config=model_config, out_dir=tmp_path, train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0), eval=EvalArgs(interval=1, max_iters=1, final_validation=False), ) @_RunIf(min_cuda_gpus=2, standalone=True) # If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise # the CLI would capture pytest args, but unfortunately patching would mess with subprocess # launching, so we need to mock `save_hyperparameters()` @mock.patch("litgpt.pretrain.save_hyperparameters") # todo: it expects exactly 2 GPUs and has strange failing for validated 4 # 
GPUs, so we temporarily mark it as xfail @pytest.mark.xfail(condition=torch.cuda.device_count() != 2, reason="This test is flaky, expects exactly 2 GPUs") def test_pretrain(_, tmp_path): model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]]) dataloader = DataLoader(dataset) pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader)) out_dir = tmp_path / "out" stdout = StringIO() with redirect_stdout(stdout): pretrain.setup( "pythia-14m", devices=2, model_config=model_config, out_dir=out_dir, train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0), eval=EvalArgs(interval=1, max_iters=1, final_validation=False), ) if torch.distributed.get_rank() == 0: # tmp_path is not the same across all ranks, run assert only on rank 0 out_dir_contents = set(os.listdir(out_dir)) checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004", "final"} assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "model_config.yaml"} assert (out_dir / "logs" / "tensorboard" / "version_0").is_dir() # logs only appear on rank 0 logs = stdout.getvalue() assert logs.count("(step)") == 4 assert logs.count("val loss") == 4 assert "Total parameters: 1,888" in logs torch.distributed.barrier() @_RunIf(min_cuda_gpus=2, standalone=True) @mock.patch("litgpt.pretrain.L.Fabric.load_raw") # See comment in `test_pretrain` why we need to mock `save_hyperparameters()` @mock.patch("litgpt.pretrain.save_hyperparameters") def test_initial_checkpoint_dir(_, load_mock, tmp_path): model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) dataset = torch.tensor([[0, 1, 2], [3, 4, 
5], [0, 1, 2]]) dataloader = DataLoader(dataset) pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader)) pretrain.fit = Mock() pretrain.setup( "pythia-14m", initial_checkpoint_dir=tmp_path, devices=torch.cuda.device_count(), model_config=model_config, out_dir=tmp_path, ) load_mock.assert_called_once_with(tmp_path / "lit_model.pth", ANY) @pytest.mark.parametrize(("strategy", "expected"), [(SingleDeviceStrategy, True), (FSDPStrategy, False)]) def test_initialize_weights(strategy, expected): fabric_mock = Mock() fabric_mock.strategy = Mock(spec=strategy) class Child(torch.nn.Module): pass class Parent(torch.nn.Module): def __init__(self): super().__init__() self.child = Child() model = Parent() model.reset_parameters = Mock() model.child.reset_parameters = Mock() initialize_weights(fabric_mock, model, n_layer=2, n_embd=8) assert model.reset_parameters.call_count == int(expected) assert model.child.reset_parameters.call_count == int(expected) ================================================ FILE: tests/test_prompts.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from typing import Optional import pytest import yaml import litgpt.config from litgpt import Config from litgpt.prompts import ( Alpaca, Default, Llama3, Phi3, PromptStyle, has_prompt_style, load_prompt_style, prompt_styles, save_prompt_style, ) def test_default_prompt_style(mock_tokenizer): prompt_style = Default() prompt = "This is a test prompt." assert prompt_style.apply(prompt) == prompt assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],) @pytest.mark.parametrize("sys_prompt", [None, "You are a helpful coding assistant."]) def test_sys_prompt(mock_tokenizer, sys_prompt: Optional[str]): prompt_style = Phi3() prompt = "This is a test prompt." default_sys_prompt = "You are a helpful assistant." 
response = f"<|system|>\n{sys_prompt or default_sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n" assert prompt_style.apply(prompt, sys_prompt=sys_prompt) == response assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],) @pytest.mark.parametrize("sys_prompt", [None, "You are a helpful coding assistant."]) def test_sys_prompt_with_kwargs(mock_tokenizer, sys_prompt: Optional[str]): prompt_style = Phi3() prompt = "This is a test prompt." default_sys_prompt = "You are a helpful assistant." response = f"<|system|>\n{sys_prompt or default_sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n" assert prompt_style.apply(prompt, sys_prompt=sys_prompt, test=1) == response assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],) def test_prompt_style_from_name(): for style_name in prompt_styles: assert isinstance(PromptStyle.from_name(style_name), prompt_styles[style_name]) def test_prompt_style_from_config(): model_names = [ "stablelm-tuned-alpha-3b", "stablelm-tuned-alpha-7b", "stablelm-zephyr-3b", "stablecode-instruct-alpha-3b", "falcon-7b-instruct", "falcon-40b-instruct", "Llama-2-7b-chat-hf", "Llama-2-13b-chat-hf", "Llama-2-70b-chat-hf", "Llama-3-8B-Instruct", "Llama-3-70B-Instruct", "Llama-3.1-405B-Instruct", "Gemma-2b-it", "Gemma-7b-it", "FreeWilly2", "CodeLlama-7b-Instruct-hf", "CodeLlama-13b-Instruct-hf", "CodeLlama-34b-Instruct-hf", "CodeLlama-70b-Instruct-hf", "phi-1_5", "phi-2", "Phi-3-mini-4k-instruct", "Mistral-7B-Instruct-v0.1", "Mistral-7B-Instruct-v0.2", "tiny-llama-1.1b-chat", "Llama-2-7b-chat-hf-function-calling-v2", ] for c in litgpt.config.platypus: model_names.append(c["name"]) for model_name in model_names: # by asserting the returned style is not the Default, we show that at least one of the regex patterns matched assert not isinstance(PromptStyle.from_config(Config.from_name(model_name)), Default) def test_apply_prompts(): prompt = "Is a coconut a nut or a fruit?" 
inp = "Optional input" for style in prompt_styles.values(): output = style().apply(prompt, input=inp) assert prompt in output if isinstance(style, Alpaca): assert inp in output class CustomPromptStyle(PromptStyle): def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str: return prompt def test_save_load_prompt_style(tmp_path): # Save and load a built-in style checkpoint_dir = tmp_path / "checkpoint" checkpoint_dir.mkdir() assert not has_prompt_style(checkpoint_dir) save_prompt_style("alpaca", checkpoint_dir) assert has_prompt_style(checkpoint_dir) with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file: contents = yaml.safe_load(file) assert contents == {"class_path": "litgpt.prompts.Alpaca"} loaded = load_prompt_style(checkpoint_dir) assert isinstance(loaded, Alpaca) # Save a custom style checkpoint_dir = tmp_path / "custom" checkpoint_dir.mkdir() save_prompt_style(CustomPromptStyle(), checkpoint_dir) with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file: contents = yaml.safe_load(file) assert contents == {"class_path": "test_prompts.CustomPromptStyle"} loaded = load_prompt_style(checkpoint_dir) assert isinstance(loaded, CustomPromptStyle) def test_multiturn_prompt(): prompt = "What is the capital of France?" msgs = [{"role": "user", "content": prompt}] style = Llama3() simple_output = style.apply(prompt) multiturn_output = style.apply(msgs) assert simple_output == multiturn_output # override system prompt msgs = [{"role": "system", "content": "You are not a helpful assistant."}, {"role": "user", "content": prompt}] with_system_multiturn_output = style.apply(msgs) assert "You are not a helpful assistant." in with_system_multiturn_output # use default system prompt msgs = [ {"role": "user", "content": prompt}, ] wo_system_multiturn_output = style.apply(msgs) assert "You are a helpful assistant." 
in wo_system_multiturn_output # Longer turn msgs = [ {"role": "system", "content": "You are a helpful AI assistant for travel tips and recommendations"}, {"role": "user", "content": "What is France's capital?"}, {"role": "assistant", "content": "Bonjour! The capital of France is Paris!"}, {"role": "user", "content": "What can I do there?"}, ] multiturn_output = style.apply(msgs) assert ( multiturn_output == """<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|> What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|> What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|> """ ) # Longer list without "system" msgs = [ {"role": "user", "content": "What is France's capital?"}, {"role": "assistant", "content": "Bonjour! The capital of France is Paris!"}, {"role": "user", "content": "What can I do there?"}, ] multiturn_output = style.apply(msgs) assert ( multiturn_output == """<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|> Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|> What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|> """ ) # {random} string format shouldn't lead to key error content = "this is {random} {system} {user}" msgs = [{"role": "user", "content": content}] output = style.apply(msgs) simple_output = style.apply(content) assert output == simple_output ================================================ FILE: tests/test_readme.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os import platform import subprocess import sys import threading import time from pathlib import Path from unittest import mock import pytest import requests from urllib3.exceptions import MaxRetryError from litgpt.utils import _RunIf, kill_process_tree REPO_ID = Path("EleutherAI/pythia-14m") CUSTOM_TEXTS_DIR = Path("custom_texts") def run_command(command): try: result = subprocess.run(command, capture_output=True, text=True, check=True) return result.stdout except subprocess.CalledProcessError as e: error_message = ( f"Command '{' '.join(command)}' failed with exit status {e.returncode}\n" f"Output:\n{e.stdout}\n" f"Error:\n{e.stderr}" ) # You can either print the message, log it, or raise an exception with it print(error_message) raise RuntimeError(error_message) from None def _wait_and_check_response(waiting: int = 30): response_status_code, err = -1, None for _ in range(waiting): try: response = requests.get("http://127.0.0.1:8000", timeout=1) response_status_code = response.status_code except (MaxRetryError, requests.exceptions.ConnectionError) as ex: response_status_code = -1 err = str(ex) if response_status_code == 200: break time.sleep(1) assert response_status_code == 200, "Server did not respond as expected. 
Error: {err}" @pytest.mark.dependency() @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_download_model(): repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI command = ["litgpt", "download", str(repo_id)] output = run_command(command) s = Path("checkpoints") / repo_id assert f"Saving converted checkpoint to {str(s)}" in output assert ("checkpoints" / REPO_ID).exists() # Also test valid but unsupported repo IDs command = ["litgpt", "download", "CohereForAI/aya-23-8B"] output = run_command(command) assert "Unsupported `repo_id`" in output @pytest.mark.dependency() @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_download_books(): CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True) books = [ ("https://www.gutenberg.org/cache/epub/24440/pg24440.txt", "book1.txt"), ("https://www.gutenberg.org/cache/epub/26393/pg26393.txt", "book2.txt"), ] for url, filename in books: subprocess.run(["curl", url, "--output", str(CUSTOM_TEXTS_DIR / filename)], check=True) # Verify each book is downloaded assert (CUSTOM_TEXTS_DIR / filename).exists(), f"{filename} not downloaded" @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @pytest.mark.dependency(depends=["test_download_model"]) def test_chat_with_model(): command = ["litgpt", "generate", "checkpoints" / REPO_ID] prompt = "What do Llamas eat?" result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True) assert "What food do llamas eat?" in result.stdout @_RunIf(min_cuda_gpus=1) @pytest.mark.dependency(depends=["test_download_model"]) def test_chat_with_quantized_model(): command = ["litgpt", "generate", "checkpoints" / REPO_ID, "--quantize", "bnb.nf4", "--precision", "bf16-true"] prompt = "What do Llamas eat?" result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True) assert "What food do llamas eat?" 
in result.stdout, result.stdout @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @pytest.mark.dependency(depends=["test_download_model"]) @pytest.mark.timeout(300) def test_finetune_model(tmp_path): OUT_DIR = tmp_path / "out" / "lora" DATASET_PATH = tmp_path / "custom_finetuning_dataset.json" CHECKPOINT_DIR = "checkpoints" / REPO_ID download_command = [ "curl", "-L", "https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json", "-o", str(DATASET_PATH), ] subprocess.run(download_command, check=True) assert DATASET_PATH.exists(), "Dataset file not downloaded" finetune_command = [ "litgpt", "finetune_lora", str(CHECKPOINT_DIR), "--lora_r", "1", "--data", "JSON", "--data.json_path", str(DATASET_PATH), "--data.val_split_fraction", "0.00001", # Keep small because new final validation is expensive "--train.max_steps", "1", "--out_dir", str(OUT_DIR), ] run_command(finetune_command) generated_out_dir = OUT_DIR / "final" assert generated_out_dir.exists(), f"Finetuning output directory ({generated_out_dir}) was not created" model_file = OUT_DIR / "final" / "lit_model.pth" assert model_file.exists(), f"Model file ({model_file}) was not created" @pytest.mark.skipif( sys.platform.startswith("win") or sys.platform == "darwin", reason="`torch.compile` is not supported on this OS." 
) @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @pytest.mark.dependency(depends=["test_download_model", "test_download_books"]) def test_pretrain_model(tmp_path): OUT_DIR = tmp_path / "out" / "custom_pretrained" pretrain_command = [ "litgpt", "pretrain", "pythia-14m", "--tokenizer_dir", str("checkpoints" / REPO_ID), "--data", "TextFiles", "--data.train_data_path", str(CUSTOM_TEXTS_DIR), "--train.max_tokens", "100", # to accelerate things for CI "--eval.max_iters", "1", # to accelerate things for CI "--out_dir", str(OUT_DIR), ] output = run_command(pretrain_command) assert "Warning: Preprocessed training data found" not in output out_dir_path = OUT_DIR / "final" assert out_dir_path.exists(), f"Pretraining output directory ({out_dir_path}) was not created" out_model_path = OUT_DIR / "final" / "lit_model.pth" assert out_model_path.exists(), f"Model file ({out_model_path}) was not created" # Test that warning is displayed when running it a second time output = run_command(pretrain_command) assert "Warning: Preprocessed training data found" in output @pytest.mark.skipif( sys.platform.startswith("win") or sys.platform == "darwin", reason="`torch.compile` is not supported on this OS." 
) @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @pytest.mark.dependency(depends=["test_download_model", "test_download_books"]) def test_continue_pretrain_model(tmp_path): OUT_DIR = tmp_path / "out" / "custom_continue_pretrained" pretrain_command = [ "litgpt", "pretrain", "pythia-14m", "--initial_checkpoint", str("checkpoints" / REPO_ID), "--tokenizer_dir", str("checkpoints" / REPO_ID), "--data", "TextFiles", "--data.train_data_path", str(CUSTOM_TEXTS_DIR), "--train.max_tokens", "100", # to accelerate things for CI "--eval.max_iters", "1", # to accelerate things for CI "--out_dir", str(OUT_DIR), ] run_command(pretrain_command) generated_out_dir = OUT_DIR / "final" assert generated_out_dir.exists(), f"Continued pretraining directory ({generated_out_dir}) was not created" model_file = OUT_DIR / "final" / "lit_model.pth" assert model_file.exists(), f"Model file ({model_file}) was not created" @pytest.mark.dependency(depends=["test_download_model"]) # todo: try to resolve this issue @pytest.mark.xfail(condition=platform.system() == "Darwin", reason="it passes locally but having some issues on CI") def test_serve(): CHECKPOINT_DIR = str("checkpoints" / REPO_ID) run_command = ["litgpt", "serve", str(CHECKPOINT_DIR)] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) stdout, stderr = process.communicate(timeout=60) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() if process: kill_process_tree(process.pid) server_thread.join() ================================================ FILE: tests/test_rope.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import torch from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXRotaryEmbedding from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama from litgpt.model import apply_rope, build_rope_cache @torch.inference_mode() def test_rope_gptneox(): bs, seq_len, n_head, n_embed = 1, 6, 2, 8 head_size = n_embed // n_head # 4 x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float() position_ids = torch.arange(seq_len).unsqueeze(0) config = GPTNeoXConfig(num_attention_heads=n_head, hidden_size=head_size * n_embed) theirs_rot_emb = GPTNeoXRotaryEmbedding(config) theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids) ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device) ours_cos_cached = ours_cos_cached.unsqueeze(0) ours_sin_cached = ours_sin_cached.unsqueeze(0) torch.testing.assert_close(ours_cos_cached, theirs_cos) torch.testing.assert_close(ours_sin_cached, theirs_sin) ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached) theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids) torch.testing.assert_close(ours_x_rope, theirs_x_rope) @torch.inference_mode() def test_rope_llama_2(): head_dim = 64 rope_theta = 10_000 ################################## # Compare cos and sin ################################## # transformer rope their_rope_config = { "rope_type": "default", } config = LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config) rot_emb = LlamaRotaryEmbedding(config=config) batch_size, seq_len = 1, 10 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) 
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta) ours_cos = ours_cos.unsqueeze(0) ours_sin = ours_sin.unsqueeze(0) torch.testing.assert_close(theirs_cos, ours_cos) torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors ################################## # Settings num_heads = 4 # Dummy query and key tensors torch.manual_seed(123) queries = torch.randn(batch_size, num_heads, seq_len, head_dim) keys = torch.randn(batch_size, num_heads, seq_len, head_dim) ours_q_rot = apply_rope(queries, ours_cos, ours_sin) ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) # See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings @torch.inference_mode() def test_rope_llama_3(): head_dim = 64 rope_theta = 50_000 ################################## # Compare cos and sin ################################## # transformer rope their_rope_config = { "rope_type": "default", } config = LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config) rot_emb = LlamaRotaryEmbedding(config=config) batch_size, seq_len = 1, 10 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta) ours_cos = ours_cos.unsqueeze(0) ours_sin = ours_sin.unsqueeze(0) torch.testing.assert_close(theirs_cos, ours_cos) torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors ################################## # Settings num_heads = 4 # Dummy query and key 
tensors torch.manual_seed(123) queries = torch.randn(batch_size, num_heads, seq_len, head_dim) keys = torch.randn(batch_size, num_heads, seq_len, head_dim) ours_q_rot = apply_rope(queries, ours_cos, ours_sin) ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) # See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings @torch.inference_mode() def test_rope_llama_3_1(): head_dim = 32 rope_theta = 50_000 their_rope_config = { "factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", } our_rope_config = {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_seq_len": 8192} config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim) ################################## # Compare cos and sin ################################## # transformer rope rot_emb = LlamaRotaryEmbedding(config=config) batch_size, seq_len = 1, 131_072 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config) ours_cos = ours_cos.unsqueeze(0) ours_sin = ours_sin.unsqueeze(0) torch.testing.assert_close(theirs_cos, ours_cos) torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors ################################## # Settings num_heads = 4 # Dummy query and key tensors torch.manual_seed(123) queries = torch.randn(batch_size, num_heads, seq_len, head_dim) keys = torch.randn(batch_size, num_heads, seq_len, head_dim) ours_q_rot = apply_rope(queries, ours_cos, 
ours_sin) ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) # See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings @torch.inference_mode() def test_rope_llama_3_2(): head_dim = 32 rope_theta = 50_000 their_rope_config = { "factor": 32.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", } our_rope_config = {"factor": 32.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_seq_len": 8192} config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim) ################################## # Compare cos and sin ################################## # transformer rope rot_emb = LlamaRotaryEmbedding(config=config) batch_size, seq_len = 1, 131_072 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config) ours_cos = ours_cos.unsqueeze(0) ours_sin = ours_sin.unsqueeze(0) torch.testing.assert_close(theirs_cos, ours_cos) torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors ################################## # Settings num_heads = 4 # Dummy query and key tensors torch.manual_seed(123) queries = torch.randn(batch_size, num_heads, seq_len, head_dim) keys = torch.randn(batch_size, num_heads, seq_len, head_dim) ours_q_rot = apply_rope(queries, ours_cos, ours_sin) ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, 
ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) # See https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json for settings @torch.inference_mode() def test_rope_gemma_3(): from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb head_dim = 32 rope_theta = 50_000 their_rope_config = { "factor": 8.0, "rope_type": "linear", } our_rope_config = {"factor": 8.0} ################################## # Compare cos and sin ################################## # transformer rope config = Gemma3TextConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim) rot_emb = Gemma3RotaryEmbedding(config=config) batch_size, seq_len = 1, 10 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config) ours_cos = ours_cos.unsqueeze(0) ours_sin = ours_sin.unsqueeze(0) torch.testing.assert_close(theirs_cos, ours_cos) torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors ################################## # Settings num_heads = 4 # Dummy query and key tensors torch.manual_seed(123) queries = torch.randn(batch_size, num_heads, seq_len, head_dim) keys = torch.randn(batch_size, num_heads, seq_len, head_dim) ours_q_rot = apply_rope(queries, ours_cos, ours_sin) ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) @torch.inference_mode() def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(): bs, seq_len, n_head, n_embed = 1, 6, 2, 8 head_size = n_embed // 
n_head # 4 rotary_percentage = 0.75 rope_n_elem = int(head_size * rotary_percentage) # 3 ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem) required_shape = (seq_len, rope_n_elem) assert ours_cos.shape == required_shape assert ours_sin.shape == required_shape # Special case: If `rope_n_elem == 1`, the shape is extended. This is to # accommodate a current bug in Hugging Face, ensuring that other unit tests # pass. # https://github.com/huggingface/transformers/issues/35233 rotary_percentage = 0.25 rope_n_elem = int(head_size * rotary_percentage) # 1 ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem) required_shape = (seq_len, rope_n_elem + 1) assert ours_cos.shape == required_shape assert ours_sin.shape == required_shape ================================================ FILE: tests/test_serve.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import json import platform import shutil import subprocess import threading import time from dataclasses import asdict import pytest import requests import torch import yaml from lightning.fabric import seed_everything from urllib3.exceptions import MaxRetryError from litgpt import GPT, Config from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf, kill_process_tree def _wait_and_check_response(waiting: int = 30): response_status_code, err = -1, None for _ in range(waiting): try: response = requests.get("http://127.0.0.1:8000", timeout=10) response_status_code = response.status_code except (MaxRetryError, requests.exceptions.ConnectionError) as ex: response_status_code = -1 err = str(ex) if response_status_code == 200: break time.sleep(1) assert response_status_code == 200, f"Server did not respond as expected. 
Error: {err}" # todo: try to resolve this issue @pytest.mark.flaky(reruns=2, reruns_delay=30) @pytest.mark.xfail(condition=platform.system() == "Darwin", reason="it passes locally but having some issues on CI") def test_simple(tmp_path): seed_everything(123) ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) run_command = ["litgpt", "serve", tmp_path] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response(waiting=60) if process: kill_process_tree(process.pid) server_thread.join() @_RunIf(min_cuda_gpus=1) def test_quantize(tmp_path): seed_everything(123) ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) run_command = ["litgpt", "serve", tmp_path, 
"--quantize", "bnb.nf4"] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() if process: kill_process_tree(process.pid) server_thread.join() @_RunIf(min_cuda_gpus=2) def test_multi_gpu_serve(tmp_path): seed_everything(123) ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) run_command = ["litgpt", "serve", tmp_path, "--devices", "2"] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() if process: kill_process_tree(process.pid) server_thread.join() @_RunIf(min_cuda_gpus=1) def test_serve_with_openai_spec_missing_chat_template(tmp_path): seed_everything(123) ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = 
tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() if process: kill_process_tree(process.pid) server_thread.join() @_RunIf(min_cuda_gpus=1) def test_serve_with_openai_spec(tmp_path): seed_everything(123) ours_config = Config.from_name("SmolLM2-135M-Instruct") download_from_hub(repo_id="HuggingFaceTB/SmolLM2-135M-Instruct", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() try: # Test server health response = requests.get("http://127.0.0.1:8000/health") assert response.status_code == 200, f"Server health check failed with status code {response.status_code}" assert response.text == 
"ok", "Server did not respond as expected." # Test non-streaming chat completion response = requests.post( "http://127.0.0.1:8000/v1/chat/completions", json={ "model": "SmolLM2-135M-Instruct", "messages": [{"role": "user", "content": "Hello!"}], }, ) assert response.status_code == 200, ( f"Non-streaming chat completion failed with status code {response.status_code}" ) response_json = response.json() assert "choices" in response_json, "Response JSON does not contain 'choices'." assert "message" in response_json["choices"][0], "Response JSON does not contain 'message' in 'choices'." assert "content" in response_json["choices"][0]["message"], ( "Response JSON does not contain 'content' in 'message'." ) assert response_json["choices"][0]["message"]["content"], "Content is empty in the response." # Test streaming chat completion stream_response = requests.post( "http://127.0.0.1:8000/v1/chat/completions", json={ "model": "SmolLM2-135M-Instruct", "messages": [{"role": "user", "content": "Hello!"}], "stream": True, }, ) assert stream_response.status_code == 200, ( f"Streaming chat completion failed with status code {stream_response.status_code}" ) for line in stream_response.iter_lines(): decoded = line.decode("utf-8").replace("data: ", "").replace("[DONE]", "").strip() if decoded: data = json.loads(decoded) assert "choices" in data, "Response JSON does not contain 'choices'." assert "delta" in data["choices"][0], "Response JSON does not contain 'delta' in 'choices'." assert "content" in data["choices"][0]["delta"], "Response JSON does not contain 'content' in 'delta'." 
finally: if process: kill_process_tree(process.pid) server_thread.join() @pytest.mark.parametrize( "generate_strategy", [ pytest.param("sequential", marks=_RunIf(min_cuda_gpus=1)), pytest.param("tensor_parallel", marks=_RunIf(min_cuda_gpus=2)), ], ) def test_serve_with_generate_strategy(tmp_path, generate_strategy): seed_everything(123) ours_config = Config.from_name("pythia-14m") download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path)) shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path)) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" torch.save(ours_model.state_dict(), checkpoint_path) config_path = tmp_path / "model_config.yaml" with open(config_path, "w", encoding="utf-8") as fp: yaml.dump(asdict(ours_config), fp) # Test with generate strategy run_command = ["litgpt", "serve", tmp_path, "--generate_strategy", generate_strategy] process = None def run_server(): nonlocal process try: process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True) except subprocess.TimeoutExpired: print("Server start-up timeout expired") server_thread = threading.Thread(target=run_server) server_thread.start() _wait_and_check_response() if process: kill_process_tree(process.pid) server_thread.join() ================================================ FILE: tests/test_tokenizer.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os import shutil import warnings from types import SimpleNamespace from unittest import mock import pytest from tokenizers import Tokenizer as HFTokenizer from tokenizers.models import BPE from transformers import AutoTokenizer from transformers.utils import cached_file import litgpt.config as config_module from litgpt import PromptStyle, Tokenizer # @pytest.mark.flaky(reruns=3, rerun_except=["AssertionError", "assert", "TypeError"]) @pytest.mark.flaky(reruns=3, reruns_delay=120) @pytest.mark.parametrize("config", config_module.configs, ids=[c["hf_config"]["name"] for c in config_module.configs]) def test_tokenizer_against_hf(config, tmp_path): config = config_module.Config(**config) repo_id = f"{config.hf_config['org']}/{config.hf_config['name']}" theirs = AutoTokenizer.from_pretrained(repo_id, token=os.getenv("HF_TOKEN")) # create a checkpoint directory that points to the HF files hf_files = {} for filename in ("tokenizer.json", "generation_config.json", "tokenizer.model", "tokenizer_config.json"): try: # download the HF tokenizer config hf_file = cached_file(path_or_repo_id=repo_id, filename=filename) hf_files[filename] = str(hf_file) except Exception as ex: warnings.warn(str(ex), RuntimeWarning) if "tokenizer.json" not in hf_files and "tokenizer.model" not in hf_files: raise ConnectionError("Unable to download any tokenizer files from HF") # Create a clean, model-specific subdirectory for this test run. # This avoids errors if previous runs or retries left files behind, ensuring the directory is always ready for fresh downloads and comparisons. 
model_dir = tmp_path / config.hf_config["name"] if model_dir.exists(): shutil.rmtree(model_dir) os.makedirs(model_dir, exist_ok=True) for filename, hf_file in hf_files.items(): shutil.copy(hf_file, model_dir / filename) ours = Tokenizer(model_dir) assert ours.vocab_size == theirs.vocab_size if config.name == "Mixtral-8x22B-v0.1": pytest.xfail(reason="Mixtral certainly lists 32000 vocab in its config") else: assert ours.vocab_size == config.vocab_size if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ", "Qwen3")): # even though their config defines it, it's set as None in HF assert isinstance(ours.bos_id, int) assert theirs.bos_token_id is None elif config.name.startswith("Falcon3"): if isinstance(ours.bos_id, int): assert theirs.bos_token_id is None else: assert ours.bos_id == theirs.bos_token_id is None else: assert ours.bos_id == theirs.bos_token_id if config.name.startswith("stablecode"): # even though their config defines it, it's set as None in HF assert ours.eos_id == 0 assert ours.eos_id == theirs.eos_token_id or theirs.eos_token_id is None else: assert ours.eos_id == theirs.eos_token_id prompt = "Hello, readers of this test!" 
prompt = PromptStyle.from_config(config).apply(prompt) actual = ours.encode(prompt) expected = theirs.encode(prompt) assert actual.tolist() == expected assert ours.decode(actual) == theirs.decode(expected, skip_special_tokens=True) if not config.name.startswith(("Mistral", "Mixtral")): decoded_output = "".join([ours.decode(x) for x in actual]) if ours.apply_decoding_fix and decoded_output[0] == " ": decoded_output = decoded_output[1:] # the "hack" adds an empty space to the beginning assert decoded_output == ours.decode(actual), type(theirs) def test_tokenizer_input_validation(): with pytest.raises(NotADirectoryError, match="The checkpoint directory does not exist"): Tokenizer("cocofruit") @pytest.mark.parametrize("use_bos_by_default", (True, False)) @pytest.mark.parametrize("encode_use_bos", (None, True, False)) @pytest.mark.parametrize("encode_use_eos", (True, False)) @pytest.mark.parametrize("processor_returns_bos", (True, False)) @pytest.mark.parametrize("fake_return_ids", ([], [34, 8, 17, 2])) def test_tokenizer_bos_eos( tmp_path, use_bos_by_default, encode_use_bos, encode_use_eos, processor_returns_bos, fake_return_ids ): # let `Tokenizers` create a proper (albeit empty) vocab in json format HFTokenizer(BPE()).save(str(tmp_path / "tokenizer.json")) tokenizer = Tokenizer(tmp_path) tokenizer.bos_id = 0 tokenizer.eos_id = 1 tokenizer.use_bos = use_bos_by_default if processor_returns_bos: fake_return_ids = [tokenizer.bos_id] + fake_return_ids fake_return_ids = SimpleNamespace(**dict(ids=fake_return_ids)) with mock.patch.object(tokenizer.processor, "encode", return_value=fake_return_ids): tokens = tokenizer.encode("Hello world", bos=encode_use_bos, eos=encode_use_eos).tolist() if encode_use_bos or (encode_use_bos is None and use_bos_by_default): assert tokens[0] == tokenizer.bos_id else: assert not tokens or tokens[0] != tokenizer.bos_id if encode_use_eos: assert tokens[-1] == tokenizer.eos_id else: assert not tokens or tokens[-1] != tokenizer.eos_id # both `bos` 
and `eos` should either not be found or occur only once at the begging (bos) # or at the end (eos) of the tokens sequence assert max([id for id, token in enumerate(tokens) if token == tokenizer.bos_id], default=0) == 0 assert max([id for id, token in enumerate(tokens[::-1]) if token == tokenizer.eos_id], default=0) == 0 ================================================ FILE: tests/test_trainer_support.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from pathlib import Path import lightning as L import pytest import torch from litgpt.api import LLM from litgpt.data import Alpaca2k from litgpt.utils import _RunIf REPO_ID = Path("EleutherAI/pythia-14m") class LitLLM(L.LightningModule): def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None): super().__init__() self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None) self.trainer_ckpt_path = trainer_ckpt_path def setup(self, stage): self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path) def training_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("validation_loss", loss, prog_bar=True) return loss def configure_optimizers(self): warmup_steps = 10 optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) return [optimizer], [scheduler] @pytest.mark.dependency() def test_download_model(): LLM.load(model="EleutherAI/pythia-14m", distribute=None) @pytest.mark.dependency(depends=["test_download_model"]) @_RunIf(min_cuda_gpus=1) def test_usecase1_pretraining_from_random_weights(tmp_path): llm = 
LLM.load("EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random") llm.save("pythia-14m-random-weights") del llm lit_model = LitLLM(checkpoint_dir="pythia-14m-random-weights", tokenizer_dir="EleutherAI/pythia-14m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128) trainer = L.Trainer( max_epochs=1, overfit_batches=2, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) text = lit_model.llm.generate("hello world") assert isinstance(text, str) @pytest.mark.dependency(depends=["test_download_model"]) @_RunIf(min_cuda_gpus=1) def test_usecase2_continued_pretraining_from_checkpoint(tmp_path): lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-14m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128) trainer = L.Trainer( accelerator="cuda", max_epochs=1, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) text = lit_model.llm.generate("hello world") assert isinstance(text, str) @pytest.mark.dependency(depends=["test_download_model", "test_usecase2_continued_pretraining_from_checkpoint"]) @_RunIf(min_cuda_gpus=1) def test_usecase3_resume_from_trainer_checkpoint(tmp_path): def find_latest_checkpoint(directory): latest_checkpoint = None latest_time = 0 for root, _, files in os.walk(directory): for file in files: if file.endswith(".ckpt"): file_path = os.path.join(root, file) file_time = os.path.getmtime(file_path) if file_time > latest_time: latest_time = file_time latest_checkpoint = file_path return latest_checkpoint lit_model = LitLLM( checkpoint_dir="EleutherAI/pythia-14m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs") ) data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128) trainer = L.Trainer( accelerator="cuda", max_epochs=1, precision="bf16-true", ) trainer.fit(lit_model, data) 
lit_model.llm.model.to(lit_model.llm.preprocessor.device) text = lit_model.llm.generate("hello world") assert isinstance(text, str) @pytest.mark.dependency(depends=["test_download_model", "test_usecase2_continued_pretraining_from_checkpoint"]) @_RunIf(min_cuda_gpus=1) def test_usecase4_manually_save_and_resume(tmp_path): lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-14m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128) trainer = L.Trainer( accelerator="cuda", max_epochs=1, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) text = lit_model.llm.generate("hello world") assert isinstance(text, str) lit_model.llm.save("finetuned_checkpoint") del lit_model lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint") trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) text = lit_model.llm.generate("hello world") assert isinstance(text, str) ================================================ FILE: tests/test_types.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. from typing import get_args from litgpt.constants import _SUPPORTED_LOGGERS from litgpt.types import LoggerChoice def test_logger_types_match_constants(): """Ensure LoggerChoice and _SUPPORTED_LOGGERS stay synchronized.""" logger_choice_args = get_args(LoggerChoice) assert logger_choice_args == _SUPPORTED_LOGGERS, ( f"LoggerChoice type args {logger_choice_args} != " f"_SUPPORTED_LOGGERS {_SUPPORTED_LOGGERS}. " f"These must stay synchronized. Update both litgpt/types.py and " f"litgpt/constants.py when adding new loggers." ) ================================================ FILE: tests/test_utils.py ================================================ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os
from contextlib import redirect_stderr
from dataclasses import asdict
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest import mock

import pytest
import torch
import torch.nn.functional as F
import yaml
from lightning import Fabric
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.pytorch.loggers import LitLogger, MLFlowLogger, WandbLogger

from litgpt import GPT
from litgpt.args import TrainArgs
from litgpt.constants import (
    _LITLOGGER_AVAILABLE,
    _MLFLOW_AVAILABLE,
    _MLFLOW_SKINNY_AVAILABLE,
    _TENSORBOARD_AVAILABLE,
    _WANDB_AVAILABLE,
)
from litgpt.parser_config import save_hyperparameters
from litgpt.utils import (
    CLI,
    CycleIterator,
    _RunIf,
    capture_hparams,
    check_file_size_on_cpu_and_warn,
    check_nvlink_connectivity,
    check_valid_checkpoint_dir,
    choose_logger,
    chunked_cross_entropy,
    copy_config_files,
    extend_checkpoint_dir,
    find_resume_path,
    fix_and_load_json,
    incremental_save,
    init_out_dir,
    instantiate_bnb_optimizer,
    instantiate_torch_optimizer,
    num_parameters,
    parse_devices,
    select_sft_generate_example,
)


def _capture_checkpoint_dir_error(checkpoint_dir: Path) -> str:
    """Call `check_valid_checkpoint_dir`, require that it exits, and return its stripped stderr output."""
    out = StringIO()
    with pytest.raises(SystemExit), redirect_stderr(out):
        check_valid_checkpoint_dir(checkpoint_dir)
    return out.getvalue().strip()


# The expected messages embed paths verbatim. On Windows those paths contain backslashes, so exact matching fails.
@_RunIf(skip_windows=True)
def test_check_valid_checkpoint_dir(tmp_path, monkeypatch):
    """`check_valid_checkpoint_dir` exits with an explanatory message for unusable checkpoint directories.

    Three cases are covered:
    1. The directory exists but lacks the required files.
    2. The path does not exist.
    3. The path does not exist, but a checkpoint is present under ``tmp_path/checkpoints``.
       In this case the message lists it as downloaded locally.
    """
    # The test runs from tmp_path, presumably because the helper looks for local downloads relative to the cwd.
    # monkeypatch.chdir restores the original cwd afterwards. A bare os.chdir would leak into every later test.
    monkeypatch.chdir(tmp_path)

    # Case 1: existing directory without any of the required files.
    out = _capture_checkpoint_dir_error(tmp_path)
    expected = f""" checkpoint_dir '{str(tmp_path.absolute())}' is missing the files: ['lit_model.pth', 'model_config.yaml', 'tokenizer.json OR tokenizer.model', 'tokenizer_config.json']. Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials See all download options by running: litgpt download """.strip()
    assert out == expected

    # Case 2: the path does not exist yet.
    checkpoint_dir = tmp_path / "checkpoints" / "stabilityai" / "stablelm-base-alpha-3b"
    out = _capture_checkpoint_dir_error(checkpoint_dir)
    expected = f""" checkpoint_dir '{str(checkpoint_dir.absolute())}' is not a checkpoint directory. Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials See all download options by running: litgpt download """.strip()
    assert out == expected

    # Case 3: a different, missing path, while `checkpoint_dir` now exists and is reported as a local download.
    checkpoint_dir.mkdir(parents=True)
    foo_checkpoint_dir = tmp_path / "foo"
    out = _capture_checkpoint_dir_error(foo_checkpoint_dir)
    expected = f""" checkpoint_dir '{str(foo_checkpoint_dir.absolute())}' is not a checkpoint directory. Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials You have downloaded locally: '{str(checkpoint_dir.absolute())}' See all download options by running: litgpt download """.strip()
    assert out == expected


def test_incremental_write(tmp_path):
    """Tensors written through `incremental_save` (some stored early) load back intact, including a custom attribute."""
    sd = {str(k): torch.randn(5, 10) for k in range(3)}
    # A custom Python attribute on a tensor that must survive the save/load round trip.
    sd["0"].someattr = 1
    sd_expected = {k: v.clone() for k, v in sd.items()}
    fn = str(tmp_path / "test.pt")
    with incremental_save(fn) as f:
        sd["0"] = f.store_early(sd["0"])
        sd["2"] = f.store_early(sd["2"])
        f.save(sd)

    # Verify the file with the default loader and with the restricted `weights_only=True` loader.
    for load_kwargs in ({}, {"weights_only": True}):
        sd_actual = torch.load(fn, **load_kwargs)
        assert sd_actual.keys() == sd_expected.keys(), f"load_kwargs={load_kwargs}"
        assert sd_actual["0"].someattr == 1, f"load_kwargs={load_kwargs}"  # requires PyTorch 2.0+
        for k, v_expected in sd_expected.items():
            torch.testing.assert_close(v_expected, sd_actual[k])


@pytest.mark.parametrize("B", (1, 2))
@pytest.mark.parametrize("ignore_index", (None, -1, -2, -100))
def test_chunked_cross_entropy(ignore_index, B):
    """`chunked_cross_entropy` agrees with `F.cross_entropy` for tensor or pre-split logits and any chunk size."""
    V = 50
    T = 25
    regular_logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))

    if ignore_index is not None:
        targets[:, [1, 4, 10, 19]] = ignore_index
    # `None` stands for PyTorch's default ignore index (-100). Normalize once so every call below uses the same value.
    ignore_index = ignore_index if ignore_index is not None else -100

    baseline_loss = F.cross_entropy(
        regular_logits.reshape(-1, regular_logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
    )

    # chunk_size=0 on unsplit logits must reproduce the baseline bit-for-bit and yield a scalar.
    regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0, ignore_index=ignore_index)
    assert torch.equal(baseline_loss, regular_loss)
    assert regular_loss.numel() == 1

    logit_chunk_size = 6
    assert T % logit_chunk_size != 0  # ensure leftover
    chunked_logits = list(regular_logits.split(logit_chunk_size, dim=1))

    # Every other combination must match up to floating-point tolerance.
    for logits, chunk_size in ((regular_logits, 10), (chunked_logits, 0), (chunked_logits, 10)):
        chunked_loss = chunked_cross_entropy(logits, targets, chunk_size=chunk_size, ignore_index=ignore_index)
        torch.testing.assert_close(chunked_loss, regular_loss)
        torch.testing.assert_close(chunked_loss, baseline_loss)


def test_num_parameters():
    """`num_parameters` counts all parameters, or only those whose `requires_grad` matches the filter."""
    # Linear(2, 2) has a 2x2 weight plus a 2-element bias: 6 parameters.
    model = torch.nn.Linear(2, 2)
    assert num_parameters(model) == 6
    assert num_parameters(model, requires_grad=True) == 6
    assert num_parameters(model, requires_grad=False) == 0

    model = torch.nn.Linear(2, 2)
    model.bias.requires_grad = False
    assert num_parameters(model) == 6
    assert num_parameters(model, requires_grad=True) == 4
    assert num_parameters(model, requires_grad=False) == 2
@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("mode", ["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"])
def test_num_parameters_bitsandbytes(mode):
    """`num_parameters` reports the logical parameter count of bitsandbytes-quantized models."""
    plugin = BitsandbytesPrecision(mode=mode)
    fabric = Fabric(plugins=plugin, accelerator="cuda", devices=1)

    model = torch.nn.Linear(10, 10)
    model = fabric.setup(model)
    # 10 * 10 weights + 10 biases, whatever the quantization mode.
    assert num_parameters(model) == 110

    with fabric.init_module(empty_init=True):
        model = GPT.from_name("pythia-14m")
    assert num_parameters(model) == 14067712


def test_cycle_iterator():
    """`CycleIterator` raises on an empty iterable and increments `epoch` each time it wraps around."""
    iterator = CycleIterator([])
    with pytest.raises(StopIteration):
        next(iterator)

    iterator = CycleIterator(range(3))
    assert iterator.epoch == 0
    # Pairs of (expected item, expected epoch after yielding it). The 4th `next` wraps to 0 and starts epoch 1.
    for expected_item, expected_epoch in ((0, 0), (1, 0), (2, 0), (0, 1)):
        assert next(iterator) == expected_item
        assert iterator.epoch == expected_epoch


def test_parse_devices():
    """`parse_devices` rejects invalid counts and resolves "auto" / -1 from the mocked CUDA device count."""
    with pytest.raises(ValueError, match="must be 'auto' or a positive integer"):
        assert parse_devices(0)
    with pytest.raises(ValueError, match="must be 'auto' or a positive integer"):
        assert parse_devices(-2)

    with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=0):
        assert parse_devices("auto") == 1  # CPU
        assert parse_devices(10) == 10  # leave validation up to Fabric later on
    with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=1):
        assert parse_devices("auto") == 1  # CUDA
    with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=3):
        assert parse_devices("auto") == 3
        assert parse_devices(-1) == 3

    assert parse_devices(5) == 5


def test_copy_config_files(fake_checkpoint_dir, tmp_path):
    """`copy_config_files` copies (at least) the model and tokenizer config files into the target directory."""
    copy_config_files(fake_checkpoint_dir, tmp_path)
    expected = {"model_config.yaml", "tokenizer_config.json", "tokenizer.json"}
    contents = set(os.listdir(tmp_path))
    assert expected.issubset(contents)


def test_capture_hparams():
    """`capture_hparams` snapshots the caller's local variables.

    Dataclasses are converted with `asdict`, and unsupported objects with `str`.
    Do not add locals before the call: they would end up in the captured dict.
    """
    integer = 1
    string = "string"
    boolean = True
    none = None
    path = Path("/path")
    dataclass = TrainArgs()
    other = torch.nn.Linear(1, 1)
    hparams = capture_hparams()
    assert hparams == {
        "integer": integer,
        "string": string,
        "boolean": boolean,
        "none": none,
        "path": path,
        "dataclass": asdict(dataclass),
        "other": str(other),
    }


def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
    """CLI target that writes its own hyperparameters to `out_dir`."""
    save_hyperparameters(_test_function, out_dir)


def test_save_hyperparameters(tmp_path):
    """Running a function through `CLI` persists its parsed arguments, defaults included, to hyperparameters.yaml."""
    with mock.patch("sys.argv", ["any.py", str(tmp_path), "--foo", "True"]):
        CLI(_test_function)

    with open(tmp_path / "hyperparameters.yaml", encoding="utf-8") as file:
        hparams = yaml.full_load(file)

    assert hparams["out_dir"] == str(tmp_path)
    assert hparams["foo"] is True
    assert hparams["bar"] == 1


def _test_function2(out_dir: Path, foo: bool = False, bar: int = 1):
    """Signature-only target: `save_hyperparameters` must read its arguments from `sys.argv` without calling it."""
    # `raise` instead of `assert False` so the guard still fires when Python runs with -O.
    raise AssertionError("I only exist as a signature, but I should not run.")


@pytest.mark.parametrize(
    "command",
    [
        "any.py",
        "litgpt finetune",
        "litgpt finetune_full",
        "litgpt finetune_lora",
        "litgpt finetune_adapter",
        "litgpt finetune_adapter_v2",
        "litgpt pretrain",
    ],
)
def test_save_hyperparameters_known_commands(command, tmp_path):
    """`save_hyperparameters` parses `sys.argv` correctly for plain scripts and for `litgpt <subcommand>` calls."""
    with mock.patch("sys.argv", [*command.split(" "), str(tmp_path), "--foo", "True"]):
        save_hyperparameters(_test_function2, tmp_path)

    with open(tmp_path / "hyperparameters.yaml", encoding="utf-8") as file:
        hparams = yaml.full_load(file)

    assert hparams["out_dir"] == str(tmp_path)
    assert hparams["foo"] is True
    assert hparams["bar"] == 1


def test_choose_logger(tmp_path):
    """`choose_logger` maps each supported name to its logger class and rejects unknown names."""
    assert isinstance(choose_logger("csv", out_dir=tmp_path, name="csv"), CSVLogger)
    if _TENSORBOARD_AVAILABLE:
        assert isinstance(choose_logger("tensorboard", out_dir=tmp_path, name="tb"), TensorBoardLogger)
    if _WANDB_AVAILABLE:
        assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)
    if _MLFLOW_AVAILABLE or _MLFLOW_SKINNY_AVAILABLE:
        assert isinstance(choose_logger("mlflow", out_dir=tmp_path, name="mlflow"), MLFlowLogger)
    if _LITLOGGER_AVAILABLE:
        assert isinstance(choose_logger("litlogger", out_dir=tmp_path, name="litlogger"), LitLogger)

    with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
        choose_logger("foo", out_dir=tmp_path, name="foo")


@pytest.mark.parametrize(
    "path_type, input_path, expected",
    [
        ("relative", "some/relative/path", "some/relative/path"),
        ("absolute", "/usr/absolute/path", "/usr/absolute/path"),
        ("env_relative", "some/relative/path", "prefix/some/relative/path"),
        ("env_absolute", "/usr/absolute/path", "/usr/absolute/path"),
    ],
)
def test_init_out_dir(path_type, input_path, expected):
    """`init_out_dir` prefixes relative paths with $LIGHTNING_ARTIFACTS_DIR when set and leaves absolute paths alone."""
    if path_type.startswith("env_"):
        with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
            result = init_out_dir(input_path)
            assert result == Path(expected), f"Failed for {path_type} with input {input_path} (result {result})"
    else:
        result = init_out_dir(input_path)
        # The real environment may already define the variable (e.g. on Lightning Studios).
        if "LIGHTNING_ARTIFACTS_DIR" not in os.environ:
            assert result == Path(expected), f"Failed for {path_type} with input {input_path} (result {result})"
        else:
            assert result == Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / expected, (
                f"Failed for {path_type} with input {input_path} (result {result})"
            )


def test_find_resume_path(tmp_path):
    """`find_resume_path` honors explicit paths and picks the latest step checkpoint for `True`/"auto"."""
    assert find_resume_path(resume=None, out_dir=Path("does/not/exist")) is None
    assert find_resume_path(resume=Path("does/not/exist"), out_dir=Path("does/not/matter")) == Path("does/not/exist")
    assert find_resume_path(resume=(tmp_path / "checkpoint.pt"), out_dir=Path("does/not/matter")) == (
        tmp_path / "checkpoint.pt"
    )

    # `resume='auto'` does not enforce the checkpoint to exist
    assert find_resume_path(resume="auto", out_dir=Path("does/not/exist")) is None

    # `resume=True` requires a checkpoint to exist
    with pytest.raises(FileNotFoundError, match="You passed `--resume=True`, but no checkpoint file was found"):
        find_resume_path(resume=True, out_dir=Path("does/not/exist"))
    with pytest.raises(FileNotFoundError, match="You passed `--resume=True`, but no checkpoint file was found"):
        find_resume_path(resume=True, out_dir=tmp_path)

    (tmp_path / "step-001").mkdir()
    (tmp_path / "step-001" / "lit_model.pth").touch()
    (tmp_path / "step-002").mkdir()
    (tmp_path / "step-002" / "lit_model.pth").touch()
    (tmp_path / "step-003").mkdir()
    (tmp_path / "step-003" / "lit_model.pth").touch()

    assert find_resume_path(resume=True, out_dir=tmp_path) == (tmp_path / "step-003" / "lit_model.pth")
    assert find_resume_path(resume="auto", out_dir=tmp_path) == (tmp_path / "step-003" / "lit_model.pth")


@pytest.fixture
def model_parameters():
    """A single trainable 2x2 parameter, enough to construct an optimizer."""
    return [torch.nn.Parameter(torch.randn(2, 2))]


def test_instantiate_bnb_optimizer_with_str(model_parameters):
    """The "AdamW" shorthand resolves to bitsandbytes' PagedAdamW."""
    import bitsandbytes as bnb

    with mock.patch("litgpt.utils.get_argument_names", return_value={"lr", "eps", "weight_decay"}):
        optimizer = instantiate_bnb_optimizer("AdamW", model_parameters)
    assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)


def test_instantiate_bnb_optimizer_with_dict(model_parameters):
    """A class_path/init_args dict resolves to PagedAdamW and forwards `init_args`."""
    import bitsandbytes as bnb

    optimizer_dict = {"class_path": "AdamW", "init_args": {"lr": 0.01}}
    with mock.patch("litgpt.utils.get_argument_names", return_value={"lr", "eps", "weight_decay"}):
        optimizer = instantiate_bnb_optimizer(optimizer_dict, model_parameters)
    assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)
    assert optimizer.param_groups[0]["lr"] == 0.01


def test_instantiate_bnb_optimizer_with_invalid_str(model_parameters):
    """Optimizers other than AdamW are rejected for bitsandbytes."""
    with pytest.raises(ValueError, match="only supports the AdamW"):
        instantiate_bnb_optimizer("SGD", model_parameters)


def test_instantiate_torch_optimizer_with_str(model_parameters):
    """A bare optimizer name is resolved inside `torch.optim` and receives the keyword arguments."""
    optimizer = instantiate_torch_optimizer("Adam", model_parameters, lr=0.01)
    assert isinstance(optimizer, torch.optim.Adam)
    assert optimizer.param_groups[0]["lr"] == 0.01


def test_instantiate_torch_optimizer_with_class(model_parameters):
    """A class_path dict is instantiated, with explicit keyword arguments overriding its `init_args`."""
    optimizer = instantiate_torch_optimizer(
        {"class_path": "torch.optim.Adam", "init_args": {"lr": 123}}, model_parameters, lr=0.02
    )
    assert isinstance(optimizer, torch.optim.Adam)
    # init args gets overridden
    assert optimizer.param_groups[0]["lr"] == 0.02


@pytest.mark.parametrize(
    "input_path, expected",
    [
        (Path("checkpoints/my_model"), Path("checkpoints/my_model")),
        (Path("checkpoints/my_model"), Path("./checkpoints/my_model")),
    ],
)
def test_extend_checkpoint_dir_is_prefixed(input_path, expected):
    """A path that already lives under `checkpoints/` is returned unchanged."""
    original_dir = Path.cwd()  # Save the current directory
    with TemporaryDirectory() as tmp_dir:
        os.chdir(tmp_dir)

        try:
            if not input_path.is_absolute():
                input_path = Path(tmp_dir) / input_path
            if not expected.is_absolute():
                expected = Path(tmp_dir) / expected
            input_path.parent.mkdir(parents=True, exist_ok=True)
            input_path.touch(exist_ok=True)
            assert extend_checkpoint_dir(input_path) == expected
        finally:
            os.chdir(original_dir)  # Reset the current directory


@pytest.mark.parametrize(
    "input_path, expected",
    [
        (Path("my_model"), Path("checkpoints/my_model")),
        (Path("my_model"), Path("./checkpoints/my_model")),
    ],
)
def test_extend_checkpoint_dir(input_path, expected):
    """An existing checkpoint under `checkpoints/<name>` resolves to that location."""
    original_dir = Path.cwd()  # Save the current directory
    with TemporaryDirectory() as tmp_dir:
        os.chdir(tmp_dir)

        try:
            if not input_path.is_absolute():
                input_path = Path(tmp_dir) / "checkpoints" / input_path
            if not expected.is_absolute():
                expected = Path(tmp_dir) / expected
            input_path.parent.mkdir(parents=True, exist_ok=True)
            input_path.touch(exist_ok=True)
            assert extend_checkpoint_dir(input_path) == expected
        finally:
            os.chdir(original_dir)  # Reset the current directory


@pytest.mark.parametrize(
    "input_path, expected",
    [
        (Path("my_model"), Path("my_model")),
        (Path("/my_model"), Path("/my_model")),
    ],
)
def test_extend_checkpoint_dir_dont_exist(input_path, expected):
    """Paths that do not exist anywhere are returned unchanged."""
    assert extend_checkpoint_dir(input_path) == expected


def test_file_size_below_limit_on_cpu(recwarn):
    """A 4.0 GB file loaded on CPU is under the size threshold: its size is returned and no size warning is raised."""
    with NamedTemporaryFile() as temp_file:
        with mock.patch("os.path.getsize", return_value=4_000_000_000):
            size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")
            assert size == 4_000_000_000
    assert not any("over 4.2 GB" in str(w.message) for w in recwarn)


def test_file_size_above_limit_on_cpu():
    """A 4.6 GB file loaded on CPU is over the threshold: its size is returned together with a UserWarning."""
    with NamedTemporaryFile() as temp_file:
        with mock.patch("os.path.getsize", return_value=4_600_000_000):
            # `match` searches every emitted UserWarning, so other warnings cannot shift the checked record.
            with pytest.warns(UserWarning, match="over 4.2 GB"):
                size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")

    assert size == 4_600_000_000


def test_file_size_above_limit_on_gpu(recwarn):
    """The CPU size threshold does not apply on GPU: a 4.6 GB file must not trigger the warning."""
    with NamedTemporaryFile() as temp_file:
        with mock.patch("os.path.getsize", return_value=4_600_000_000):
            size = check_file_size_on_cpu_and_warn(temp_file.name, "gpu")
            assert size == 4_600_000_000
    assert not any("over 4.2 GB" in str(w.message) for w in recwarn)


@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
    """Fixture to mock torch.cuda.is_available() to return True."""
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)


@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
    """Fixture to mock torch.cuda.get_device_properties() for NVIDIA GPUs."""
    mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    mock_device_properties.name = "NVIDIA RTX A6000"
    monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def mock_amd_device_properties(monkeypatch):
    """Fixture to mock torch.cuda.get_device_properties() for AMD GPUs."""
    mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    mock_device_properties.name = "AMD Instinct MI250X"
    monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def all_nvlink_connected_output():
    """Mocked `subprocess.run` result: a 4-GPU topology matrix in which every GPU pair is linked via NV12."""
    return mock.MagicMock(
        stdout=""" GPU0 GPU1 GPU2 GPU3 GPU0 X NV12 NV12 NV12 GPU1 NV12 X NV12 NV12 GPU2 NV12 NV12 X NV12 GPU3 NV12 NV12 NV12 X""",
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_all_nvlink_connected(
    mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
    """A fully NVLink-connected NVIDIA topology is reported as fully connected."""
    mock_run.return_value = all_nvlink_connected_output
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
@pytest.fixture
def nvlink_partially_connected_output():
    """Mocked `subprocess.run` result for a 4-GPU topology with only partial NVLink connectivity.

    GPU0<->GPU1 and GPU2<->GPU3 are linked via NV1. Every other GPU pair is SYS.
    """
    return mock.MagicMock(
        stdout=""" GPU0 GPU1 GPU2 GPU3 CPU Affinity GPU0 X NV1 SYS SYS 0-7 GPU1 NV1 X SYS SYS 0-7 GPU2 SYS SYS X NV1 8-15 GPU3 SYS SYS NV1 X 8-15 Legend: X = Self NV1 = Connected via NVLink with 1 hop SYS = Connected via the PCIe or CPU subsystem""",
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(
    mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
    """Any GPU pair not linked via NVLink triggers the slow-interconnect warning."""
    mock_run.return_value = nvlink_partially_connected_output
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call(
            "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
        )


@pytest.fixture
def nvlink_not_connected_output():
    """Mocked `subprocess.run` result for a 4-GPU topology where every GPU pair is PHB (no NVLink at all)."""
    return mock.MagicMock(
        stdout=""" GPU0 GPU1 GPU2 GPU3 CPU Affinity NUMA Affinity GPU NUMA ID GPU0 X PHB PHB PHB 0-47 0 N/A GPU1 PHB X PHB PHB 0-47 0 N/A GPU2 PHB PHB X PHB 0-47 0 N/A GPU3 PHB PHB PHB X 0-47 0 N/A Legend: X = Self SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) PIX = Connection traversing at most a single PCIe bridge NV# = Connection traversing a bonded set of # NVLinks""",
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(
    mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
    """A topology without NVLink triggers the same slow-interconnect warning."""
    mock_run.return_value = nvlink_not_connected_output
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call(
            "Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
            "It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
        )


@pytest.fixture
def nvlink_all_gpu_connected_but_other_connected_output():
    """Mocked `subprocess.run` result for 8 GPUs that are all NV12-linked, plus 10 NICs linked via PXB/PIX/SYS.

    The NIC rows and columns must not affect the GPU-to-GPU verdict.
    """
    return mock.MagicMock(
        stdout=""" GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 CPU Affinity NUMA Affinity GPU NUMA ID GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A GPU3 NV12 NV12 NV12 X NV12 NV12 NV12 NV12 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12 SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB 64-127,192-254 1 N/A GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB 64-127,192-254 1 N/A GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A NIC0 SYS SYS PXB PXB SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS SYS SYS SYS NIC1 SYS SYS PXB PXB SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS SYS SYS SYS NIC2 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS X PXB SYS SYS SYS SYS SYS SYS NIC3 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS PXB X SYS SYS SYS SYS SYS SYS NIC4 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS X PXB SYS SYS SYS SYS NIC5 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS PXB X SYS SYS SYS SYS NIC6 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PIX SYS SYS NIC7 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PIX X SYS SYS NIC8 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PXB NIC9 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PXB X Legend: X = Self SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) PIX = Connection traversing at most a single PCIe bridge NV# = Connection traversing a bonded set of # NVLinks NIC Legend: NIC0: mlx5_0 NIC1: mlx5_1 NIC2: mlx5_2 NIC3: mlx5_3 NIC4: mlx5_4 NIC5: mlx5_5 NIC6: mlx5_6 NIC7: mlx5_7 NIC8: mlx5_8 NIC9: mlx5_9 """,
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(
    mock_run,
    nvlink_all_gpu_connected_but_other_connected_output,
    mock_cuda_is_available_true,
    mock_nvidia_device_properties,
):
    """Non-NVLink links to NICs do not count against full GPU connectivity."""
    mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
    """Mocked `subprocess.run` result for two GPUs linked via NV1, both on NUMA node 0."""
    return mock.MagicMock(
        stdout=""" GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID GPU0 X NV1 0-15 0 N/A GPU1 NV1 X 0-15 0 N/A Legend: X = Self SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) PIX = Connection traversing at most a single PCIe bridge NV# = Connection traversing a bonded set of # NVLinks """,
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
    mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties
):
    """Two NVLink-connected GPUs are reported as fully connected."""
    mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
    """
    Mocked `rocm-smi --showtopotype` output (ROCm 6.0.3+) for 8 AMD GPUs in which every GPU pair is linked via XGMI.
    """
    return mock.MagicMock(
        stdout=""" =============================== ROCm System Management Interface ============================ =============================== Link Type between two GPUs =============================== GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0 ================================== End of ROCm SMI Log =================================== """,
        returncode=0,
    )


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
    mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
    """On an AMD device name, the XGMI matrix is parsed and reported as fully connected via XGMI."""
    mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("All GPUs are fully connected via XGMI.")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
    """With CUDA reported unavailable, the check only prints that no GPUs are available."""
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(
    mock_run, monkeypatch, mock_cuda_is_available_true
):
    """A device name matching neither NVIDIA nor AMD is reported as an unrecognized vendor."""
    mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
    mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
    monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
        mock_print.assert_any_call("Unrecognized GPU vendor: GARAGE DIY HYPERSCALER GPU")


def test_fix_and_load_json():
    """`fix_and_load_json` repairs a trailing comma and missing commas between properties before parsing."""
    # Test 1: Invalid JSON string with a trailing comma
    invalid_json_trailing_comma = """ { "_from_model_config": true, "bos_token_id": 128000, "eos_token_id": 128001, "transformers_version": "4.45.0.dev0", "do_sample": true, "temperature": 0.6, "top_p": 0.9, } """

    expected_output_trailing_comma = {
        "_from_model_config": True,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "transformers_version": "4.45.0.dev0",
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }

    result_trailing_comma = fix_and_load_json(invalid_json_trailing_comma)
    assert result_trailing_comma == expected_output_trailing_comma

    # Test 2: Invalid JSON string with missing commas between properties
    invalid_json_missing_commas = """ { "_from_model_config": true, "bos_token_id": 128000, "eos_token_id": 128001, "transformers_version": "4.45.0.dev0" "do_sample": true, "temperature": 0.6, "top_p": 0.9, } """

    expected_output_missing_commas = {
        "_from_model_config": True,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "transformers_version": "4.45.0.dev0",
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }

    result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
    assert result_missing_commas == expected_output_missing_commas


def test_select_sft_generate_example():
    """`select_sft_generate_example` picks an instruction from the test split and falls back to the train split.

    Exercised modes: "first", "random" (with `random.randint` patched), and an integer index.
    An index missing from both splits raises IndexError. An unknown mode raises ValueError.
    """
    eval_mock = mock.MagicMock()
    data_mock = mock.MagicMock()

    test_dataset = {"data": [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]}
    train_dataset = {"data": [{"instruction": "Train instruction 1"}, {"instruction": "Train instruction 2"}]}

    data_mock.test_dataset.data = test_dataset["data"]
    data_mock.train_dataset.data = train_dataset["data"]

    # Test "first" instruction from test dataset
    eval_mock.evaluate_example = "first"
    instruction = select_sft_generate_example(eval_mock, data_mock)
    assert instruction == "Test instruction 1"

    # Test "first" instruction from train dataset when test dataset is empty
    data_mock.test_dataset.data = []
    instruction = select_sft_generate_example(eval_mock, data_mock)
    assert instruction == "Train instruction 1"

    # Test random selection from test dataset
    eval_mock.evaluate_example = "random"
    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
    with mock.patch("random.randint", return_value=1):
        instruction = select_sft_generate_example(eval_mock, data_mock)
        assert instruction == "Test instruction 2"

    # Test random selection from train dataset when test dataset is empty
    data_mock.test_dataset.data = []
    with mock.patch("random.randint", return_value=1):
        instruction = select_sft_generate_example(eval_mock, data_mock)
        assert instruction == "Train instruction 2"

    # Test specific index from test dataset
    eval_mock.evaluate_example = 1
    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
    instruction = select_sft_generate_example(eval_mock, data_mock)
    assert instruction == "Test instruction 2"

    # Test specific index from train dataset when test dataset has fewer elements
    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
    instruction = select_sft_generate_example(eval_mock, data_mock)
    assert instruction == "Train instruction 2"

    # Test out-of-range index
    eval_mock.evaluate_example = 2
    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
    data_mock.train_dataset.data = [{"instruction": "Train instruction 1"}]
    with pytest.raises(IndexError):
        select_sft_generate_example(eval_mock, data_mock)

    # Test unknown evaluation type
    eval_mock.evaluate_example = "unknown"
    with pytest.raises(ValueError):
        select_sft_generate_example(eval_mock, data_mock)


================================================
FILE: tests/test_yarn.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM

from litgpt import Config
from litgpt.model import Block


@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device):
    """Test DeepSeek V3 block (attention + MLP + norms) with YaRN RoPE scaling - litgpt vs hf

    Both blocks receive the same weights, inputs and (LitGPT-built) cos/sin tables. Their outputs must match
    within atol=1e-5 / rtol=1e-4. The HF rotary embedding is computed only for the printed diagnostics.
    """
    # Use layer_idx=0 to test dense MLP instead of MoE
    layer_idx = 0

    # YaRN configuration
    yarn_config = dict(
        factor=8.0,
        beta_fast=32.0,
        beta_slow=1.0,
        original_max_seq_len=4096,
        mscale=1.0,
        mscale_all_dim=0.8,
    )

    config_litgpt = Config(
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        head_size=16,
        norm_eps=1e-6,
        norm_class_name="RMSNorm",
        bias=False,
        parallel_residual=False,
        mlp_class_name="LLaMAMoE",
        intermediate_size=128,
        rope_interleave=True,
        rope_adjustments=yarn_config,  # YaRN config
        latent_attention={
            "q_lora_rank": 32,
            "kv_lora_rank": 16,
            "qk_rope_head_dim": 8,
            "qk_nope_head_dim": 8,
            "v_head_dim": 16,
        },
        first_k_dense_replace=3,  # Use dense MLP for first 3 layers
    )

    # HF config with YaRN. The key names mirror `yarn_config`, except that
    # original_max_seq_len maps to original_max_position_embeddings.
    rope_parameters = {
        "type": "yarn",
        "rope_theta": 10000.0,
        "factor": yarn_config["factor"],
        "beta_fast": yarn_config["beta_fast"],
        "beta_slow": yarn_config["beta_slow"],
        "original_max_position_embeddings": yarn_config["original_max_seq_len"],
        "mscale": yarn_config["mscale"],
        "mscale_all_dim": yarn_config["mscale_all_dim"],
    }

    # NOTE(review): `padded_vocab_size` is a LitGPT Config field, not a documented DeepseekV3Config argument.
    # It is presumably just stored as an extra attribute by HF. Confirm it is not needed here.
    config_hf = DeepseekV3Config(
        padded_vocab_size=10000,
        num_hidden_layers=1,
        vocab_size=10000,
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=4,
        num_key_value_heads=4,
        q_lora_rank=32,
        kv_lora_rank=16,
        qk_rope_head_dim=8,
        qk_nope_head_dim=8,
        v_head_dim=16,
        rope_interleave=True,
        first_k_dense_replace=3,
        rms_norm_eps=1e-6,
        rope_scaling=rope_parameters,  # YaRN config
    )

    # Debug: Check if HF config has rope_parameters
    # NOTE(review): the printed label says `rope_parameters`, but the value shown is `config_hf.rope_scaling`.
    print("\n=== HF Config Debug ===")
    print(f"config_hf.rope_parameters: {config_hf.rope_scaling}")

    block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device)
    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
    block_hf = model_hf.model.layers[layer_idx]

    block_litgpt.eval()
    block_hf.eval()

    # Copy LitGPT's randomly initialized weights into the HF block so the outputs are comparable.
    sync_block_weights(block_litgpt, block_hf)

    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)

    # Prepare RoPE sin/cos tables using YaRN computation
    from litgpt.model import build_rope_cache

    rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"]

    # Build YaRN RoPE cache for LitGPT
    cos_litgpt, sin_litgpt = build_rope_cache(
        seq_len=seq_len,
        n_elem=rope_head_dim,
        device=device,
        base=config_litgpt.rope_base,
        extra_config={
            "factor": yarn_config["factor"],
            "beta_fast": yarn_config["beta_fast"],
            "beta_slow": yarn_config["beta_slow"],
            "original_max_seq_len": yarn_config["original_max_seq_len"],
            "mscale": yarn_config["mscale"],
            "mscale_all_dim": yarn_config["mscale_all_dim"],
        },
    )

    # Get YaRN RoPE embeddings from HF (rotary_emb is on model level, not layer level)
    rotary_emb = model_hf.model.rotary_emb
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    cos_hf, sin_hf = rotary_emb(hidden_states, position_ids)

    # Expand dimensions for batch and broadcast
    cos_litgpt = cos_litgpt.unsqueeze(0).expand(batch_size, -1, -1)
    sin_litgpt = sin_litgpt.unsqueeze(0).expand(batch_size, -1, -1)

    # Compare RoPE embeddings first (printed only; nothing here is asserted)
    print("\n=== RoPE Embeddings Comparison ===")
    print(f"LitGPT cos/sin shape: {cos_litgpt.shape}, {sin_litgpt.shape}")
    print(f"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}")
    print(f"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}")
    print(f"Sin max diff: {(sin_litgpt - sin_hf).abs().max()}")
    print(f"\nLitGPT cos sample [0,0,:4]: {cos_litgpt[0, 0, :4]}")
    print(f"HF cos sample [0,0,:4]: {cos_hf[0, 0, :4]}")
    print(f"LitGPT cos min/max: {cos_litgpt.min():.4f} / {cos_litgpt.max():.4f}")
    print(f"HF cos min/max: {cos_hf.min():.4f} / {cos_hf.max():.4f}")

    # Check inv_freq from both
    print("\n=== Checking inv_freq ===")
    print(f"HF rotary_emb.inv_freq shape: {rotary_emb.inv_freq.shape}")
    print(f"HF inv_freq: {rotary_emb.inv_freq}")
    print(f"HF attention_scaling: {rotary_emb.attention_scaling}")

    # Use the same embeddings for both (LitGPT's)
    cos = cos_litgpt
    sin = sin_litgpt

    # Additive causal mask for the HF block: -inf strictly above the diagonal, broadcast to (batch, 1, seq, seq).
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1
    )
    attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

    # Run forward passes
    output_litgpt = block_litgpt(hidden_states, cos, sin)
    output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)

    # Depending on the transformers version, the decoder layer may return a tuple whose first item is the hidden state.
    if isinstance(output_hf, tuple):
        output_hf = output_hf[0]

    max_diff = (output_litgpt - output_hf).abs().max()

    print("\n=== DEBUG INFO ===")
    print(f"Max diff: {max_diff}")
    print(f"Output litgpt mean: {output_litgpt.mean()}, std: {output_litgpt.std()}")
    print(f"Output hf mean: {output_hf.mean()}, std: {output_hf.std()}")
    print(f"Cos/sin shape: {cos.shape}, {sin.shape}")
    print(f"Hidden states shape: {hidden_states.shape}")

    # Check if the issue is in attention or MLP
    # (in practice this only reports an output shape mismatch before the numeric comparison)
    if hasattr(output_litgpt, "shape") and hasattr(output_hf, "shape"):
        if output_litgpt.shape != output_hf.shape:
            print(f"Shape mismatch! litgpt: {output_litgpt.shape}, hf: {output_hf.shape}")

    assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), f"FAILED: Max diff: {max_diff}"


def sync_weights(litgpt_model, hf_model):
    """Copies weights from lit-gpt model to HF model.

    Copies the latent (MLA) attention parameters in place, under `torch.no_grad()`:
    q_a_proj, q_a_norm -> q_a_layernorm, q_b_proj, kv_a_proj_with_mqa, kv_a_norm -> kv_a_layernorm,
    kv_b_proj, and proj -> o_proj.
    """
    print("Synchronizing weights...")
    with torch.no_grad():
        hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)
        hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)
        hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)
        hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)
        hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)
        hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)
        hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)
    print("Synchronization complete.")


def sync_block_weights(block_litgpt, block_hf):
    """Synchronize all weights from LitGPT block to HF block.

    Covers attention (via `sync_weights`), the dense MLP (fc_1 -> gate_proj, fc_2 -> up_proj, proj -> down_proj)
    and both RMSNorm layers (norm_1 -> input_layernorm, norm_2 -> post_attention_layernorm).
    """
    print("Synchronizing block weights...")

    with torch.no_grad():
        # Sync attention weights
        sync_weights(block_litgpt.attn, block_hf.self_attn)

        # Sync MLP weights (assumes dense MLP, not MoE)
        block_hf.mlp.gate_proj.weight.copy_(block_litgpt.mlp.fc_1.weight)
        block_hf.mlp.up_proj.weight.copy_(block_litgpt.mlp.fc_2.weight)
        block_hf.mlp.down_proj.weight.copy_(block_litgpt.mlp.proj.weight)

        # Sync normalization layers
        block_hf.input_layernorm.weight.copy_(block_litgpt.norm_1.weight)
        block_hf.post_attention_layernorm.weight.copy_(block_litgpt.norm_2.weight)

    print("Block synchronization complete.")


================================================
FILE: tutorials/0_to_litgpt.md
================================================
# Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs

This tutorial walks you through the main features and usage patterns for ⚡️LitGPT, a library for pretraining, finetuning, and using LLMs that focuses on an efficient user experience while being developer-friendly.
The topics, following the installation of LitGPT, are in chronological order, reflecting the steps in an LLM lifecycle: Pretraining → Finetuning → Inference.       However, it is also possible, and even common, to use and deploy models with LitGPT without pretraining and finetuning. So, if you are not interested in pretraining and finetuning, please feel free to skip these sections.   ## Install LitGPT LitGPT is available as a Python library from the PyPI package repository, and we recommend installing it using Python's `pip` installer module, including all required package dependencies: ```bash pip install 'litgpt[all]' ``` Alternatively, if you are a researcher or developer planning to make changes to LitGPT, you can clone the GitHub repository and install it from a local folder as follows: ``` git clone https://github.com/Lightning-AI/litgpt.git cd litgpt pip install -e '.[all]' ```   ## Pretrain LLMs Pretraining LLMs requires substantial compute resources and time commitment. For that reason, most researchers and practitioners prefer to skip this step and continue with the *Download pretrained model weights* section instead. However, if you feel adventurous and want to pretrain your own LLM, here's how. First, we have to decide which type of model architecture we want to use. We list the available architectures by using the `pretrain` command with `list` as an argument: ```bash litgpt pretrain list ``` This prints a list of all available model architectures in alphabetical order: ``` Camel-Platypus2-13B Camel-Platypus2-70B CodeLlama-13b-Python-hf ... EleutherAI/pythia-410m ... vicuna-13b-v1.3 vicuna-13b-v1.5 vicuna-13b-v1.5-16k vicuna-33b-v1.3 vicuna-7b-v1.3 vicuna-7b-v1.5 vicuna-7b-v1.5-16k ``` Suppose we want to pretrain the small 1.1B-parameter `tiny-llama-1.1b` model. Before starting pretraining, we must also choose and download a tokenizer. We can download a tokenizer via the `download` command.
Note that running `litgpt download list` will also print a list of all available models and tokenizers to download. To filter for specific models, e.g., TinyLlama, we can use the `grep` command in our terminal: ```bash litgpt download list | grep TinyLlama ``` This prints ``` TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T TinyLlama/TinyLlama-1.1B-Chat-v1.0 ``` Let's now download the tokenizer corresponding to `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T` that we can then use to pretrain the TinyLlama model: ``` litgpt download \ TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \ --tokenizer_only true ``` (when specified)     Next, we can pretrain the model on the OpenWebText dataset with the default setting as follows: ```bash litgpt pretrain tiny-llama-1.1b \ --data OpenWebText \ --tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T ``` If you are interested in additional settings, you can use the help command as follows: ``` litgpt pretrain --help ```   > [!TIP] > Above, we only covered the most basic commands for pretraining a model using LitGPT. We highly recommend checking the resources below if you are interested in pretraining a model.   
**More information and additional resources** - [tutorials/pretrain](./pretrain.md): General information about pretraining in LitGPT - [tutorials/pretrain_tinyllama](./pretrain_tinyllama.md): A tutorial for pretraining a 1.1B TinyLlama model on 3 trillion tokens - [config_hub/pretrain](../config_hub/pretrain): Pre-made config files for pretraining that work well out of the box - Project templates in reproducible environments with multi-GPU and multi-node support: - [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) - [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) - [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)   ## Download pretrained model weights Most practical use cases, like LLM inference (/chat) or finetuning, involve using pretrained model weights. LitGPT supports a large number of model weights, which can be listed by executing the `download` command with `list` as an argument: ```bash litgpt download list ``` This will print a (long) list of all supported pretrained models (abbreviated for readability below): ``` .. google/gemma-2b ... meta-llama/Llama-2-7b-hf ... microsoft/phi-2 ... mistralai/Mixtral-8x7B-Instruct-v0.1 ... ``` To download the model weights, provide one of the model strings above as input argument: ```bash litgpt download microsoft/phi-2 ``` ``` model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s] model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s] tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s] ... Converting checkpoint files to LitGPT format. Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin ...
Saving converted checkpoint to checkpoints/microsoft/phi-2 ```   > [!TIP] > Note that some models, such as Llama 2, require that you accept Meta AI's terms of service for this model, and you need to use a special access token via the `litgpt download ... --access_token ...` option. For more information, visit the respective Model Hub website, e.g., [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf). The access token can be created under your Model Hub in the `Profile > Access Tokens` menu.   By default, the weights are going to be stored in a `./checkpoints` subdirectory: ```bash ls -lh checkpoints/microsoft/phi-2/ ``` ``` total 11G -rw-r--r-- 1 sebastian sebastian 863 Mar 19 21:14 config.json -rw-r--r-- 1 sebastian sebastian 124 Mar 19 21:14 generation_config.json -rw-r--r-- 1 sebastian sebastian 5.2G Mar 19 21:15 lit_model.pth -rw-r--r-- 1 sebastian sebastian 4.7G Mar 19 21:15 model-00001-of-00002.bin -rw-r--r-- 1 sebastian sebastian 538M Mar 19 21:15 model-00002-of-00002.bin -rw-r--r-- 1 sebastian sebastian 528 Mar 19 21:15 model_config.yaml -rw-r--r-- 1 sebastian sebastian 2.1M Mar 19 21:14 tokenizer.json -rw-r--r-- 1 sebastian sebastian 7.2K Mar 19 21:14 tokenizer_config.json ``` The model is now ready for inference and chat, for example, using the `chat` command on the checkpoint directory: ```bash litgpt chat microsoft/phi-2 ``` ``` Now chatting with phi-2. To exit, press 'Enter' on an empty prompt. Seed set to 1234 >> Prompt: Why are LLMs so useful? >> Reply: When building applications or operating systems, you can use LLMs to know how a computer should respond to your commands. This can make your programs run faster and more efficiently. Time for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens >> Prompt: ```   > [!TIP] > Use `--multiline true` to support prompts that require multiple input lines.
  **More information and additional resources** - [tutorials/download_model_weights](download_model_weights.md): A more comprehensive download tutorial, tips for GPU memory limitations, and more   ## Finetune LLMs LitGPT supports several methods of supervised instruction finetuning, which allows you to finetune models to follow instructions. Datasets for Instruction-finetuning are usually formatted in the following way:     Alternatively, datasets for instruction finetuning can also contain an `'input'` field: In an instruction-finetuning context, "full" finetuning means updating all model parameters as opposed to only a subset. Adapter and LoRA (short for low-rank adaptation) are methods for parameter-efficient finetuning that only require updating a small fraction of the model weights.     Parameter-efficient finetuning is much more resource-efficient and cheaper than full finetuning, and it often results in the same good performance on downstream tasks. In the following example, we will use LoRA for finetuning, which is one of the most popular LLM finetuning methods. (For more information on how LoRA works, please see [Code LoRA from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).) Before we start, we have to download a model as explained in the previous "Download pretrained model" section above: ```bash litgpt download microsoft/phi-2 ``` The LitGPT interface can be used via command line arguments and configuration files. We recommend starting with the configuration files from the [config_hub](../config_hub) and either modifying them directly or overriding specific settings via the command line. For example, we can use the following setting to train the downloaded 2.7B parameter `microsoft/phi-2` model, where we set `--max_steps 5` for a quick test run. 
If you have downloaded or cloned the LitGPT repository, you can provide the `config` file via a relative path: ```bash litgpt finetune_lora microsoft/phi-2\ --config config_hub/finetune/phi-2/lora.yaml \ --train.max_steps 5 ``` Alternatively, you can provide a URL: ```bash litgpt finetune_lora microsoft/phi-2\ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/phi-2/lora.yaml \ --train.max_steps 5 ```   > [!TIP] > Note that the config file above will finetune the model on the `Alpaca2k` dataset on 1 GPU and save the resulting files in an `out/finetune/lora-phi-2` directory. All of these settings can be changed via a respective command line argument or by changing the config file. > To see more options, execute `litgpt finetune_lora --help`.   Running the previous finetuning command will initiate the finetuning process, which should only take about a minute on a GPU due to the `--train.max_steps 5` setting. ``` {'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'), # TODO 'data': Alpaca2k(mask_prompt=False, val_split_fraction=0.03847, prompt_style=, ignore_index=-100, seed=42, num_workers=4, download_dir=PosixPath('data/alpaca2k')), 'devices': 1, 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100), 'logger_name': 'csv', 'lora_alpha': 16, 'lora_dropout': 0.05, 'lora_head': True, 'lora_key': True, 'lora_mlp': True, 'lora_projection': True, 'lora_query': True, 'lora_r': 8, 'lora_value': True, 'num_nodes': 1, 'out_dir': PosixPath('out/finetune/lora-phi-2'), 'precision': 'bf16-true', 'quantize': None, 'seed': 1337, 'train': TrainArgs(save_interval=800, log_interval=1, global_batch_size=8, micro_batch_size=4, lr_warmup_steps=10, epochs=1, max_tokens=None, max_steps=5, max_seq_length=512, tie_embeddings=None, learning_rate=0.0002, weight_decay=0.0, beta1=0.9, beta2=0.95, max_norm=None, min_lr=6e-05)} Seed set to 1337 Number of trainable parameters: 12,226,560 Number of non-trainable parameters: 2,779,683,840 The 
longest sequence length in the train data is 512, the model's maximum sequence length is 512 and context length is 2048 Validating ... Recommend a movie for me to watch during the weekend and explain the reason. Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Recommend a movie for me to watch during the weekend and explain the reason. ### Response: I recommend you watch "Parasite" because it's a critically acclaimed movie that won multiple awards, including the Academy Award for Best Picture. It's a thought-provoking and suspenseful film that will keep you on the edge of your seat. The movie also tackles social and economic inequalities, making it a must-watch for anyone interested in meaningful storytelling. /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MeanMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated. 
warnings.warn(*args, **kwargs) # noqa: B028 Missing logger folder: out/finetune/lora-phi-2/logs/csv Epoch 1 | iter 1 step 0 | loss train: 1.646, val: n/a | iter time: 820.31 ms Epoch 1 | iter 2 step 1 | loss train: 1.660, val: n/a | iter time: 548.72 ms (step) Epoch 1 | iter 3 step 1 | loss train: 1.687, val: n/a | iter time: 300.07 ms Epoch 1 | iter 4 step 2 | loss train: 1.597, val: n/a | iter time: 595.27 ms (step) Epoch 1 | iter 5 step 2 | loss train: 1.640, val: n/a | iter time: 260.75 ms Epoch 1 | iter 6 step 3 | loss train: 1.703, val: n/a | iter time: 568.22 ms (step) Epoch 1 | iter 7 step 3 | loss train: 1.678, val: n/a | iter time: 511.70 ms Epoch 1 | iter 8 step 4 | loss train: 1.741, val: n/a | iter time: 514.14 ms (step) Epoch 1 | iter 9 step 4 | loss train: 1.689, val: n/a | iter time: 423.59 ms Epoch 1 | iter 10 step 5 | loss train: 1.524, val: n/a | iter time: 603.03 ms (step) Training time: 11.20s Memory used: 13.90 GB Saving LoRA weights to 'out/finetune/lora-phi-2/final/lit_model.pth.lora' Saved merged weights to 'out/finetune/lora-phi-2/final/lit_model.pth' ``` Notice that the LoRA script saves both the LoRA weights (`'out/finetune/lora-phi-2/final/lit_model.pth.lora'`) and the LoRA weight merged back into the original model (`'out/finetune/lora-phi-2/final/lit_model.pth'`) for convenience. This allows us to use the finetuned model via the `chat` function directly: ```bash litgpt chat out/finetune/lora-phi-2/final/ ``` ``` Now chatting with phi-2. To exit, press 'Enter' on an empty prompt. Seed set to 1234 >> Prompt: Why are LLMs so useful? >> Reply: LLMs are useful because they can be trained to perform various natural language tasks, such as language translation, text generation, and question-answering. They are also able to understand the context of the input data, which makes them particularly useful for tasks such as sentiment analysis and text summarization. 
Additionally, because LLMs can learn from large amounts of data, they are able to generalize well and perform well on new data. Time for inference: 2.15 sec total, 39.57 tokens/sec, 85 tokens >> Prompt: ```   **More information and additional resources** - [tutorials/prepare_dataset](prepare_dataset.md): A summary of all out-of-the-box supported datasets in LitGPT and utilities for preparing custom datasets - [tutorials/finetune](finetune.md): An overview of the different finetuning methods supported in LitGPT - [tutorials/finetune_full](finetune_full.md): A tutorial on full-parameter finetuning - [tutorials/finetune_lora](finetune_lora.md): Options for parameter-efficient finetuning with LoRA and QLoRA - [tutorials/finetune_adapter](finetune_adapter.md): A description of the parameter-efficient Llama-Adapter methods supported in LitGPT - [tutorials/oom](oom.md): Tips for dealing with out-of-memory (OOM) errors - [config_hub/finetune](../config_hub/finetune): Pre-made config files for finetuning that work well out of the box   ## LLM inference To use a downloaded or finetuned model for chat, you only need to provide the corresponding checkpoint directory containing the model and tokenizer files. For example, to chat with the phi-2 model from Microsoft, download it as follows, as described in the "Download pretrained model" section: ```bash litgpt download microsoft/phi-2 ``` ``` model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s] model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s] tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s] ... Converting checkpoint files to LitGPT format. Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin ... 
Saving converted checkpoint to checkpoints/microsoft/phi-2 ``` Then, chat with the model using the following command: ```bash litgpt chat microsoft/phi-2 ``` ``` Now chatting with phi-2. To exit, press 'Enter' on an empty prompt. Seed set to 1234 >> Prompt: What is the main difference between a large language model and a traditional search engine? >> Reply: A large language model uses deep learning algorithms to analyze and generate natural language, while a traditional search engine uses algorithms to retrieve information from web pages. Time for inference: 1.14 sec total, 26.26 tokens/sec, 30 tokens ``` > [!TIP] > Most model weights are already represented in an efficient bfloat16 format. However, if the model currently exceeds your GPU memory, you can try to pass the `--precision bf16-true` option. In addition, you can check the quantization documentation for further optimization, which is linked below.   **More information and additional resources** - [tutorials/inference](inference.md): Chat and inference tutorial - [tutorials/quantize](quantize.md): Quantizing models to reduce GPU memory requirements   ## Using the LitGPT Python API for Inference The previous section explained how to use the `litgpt chat` command line interface for inference. Alternatively, LitGPT also offers a Python API approach to generate text using an LLM: ```python from litgpt import LLM llm = LLM.load("microsoft/phi-2") text = llm.generate("What do Llamas eat?", top_k=1, max_new_tokens=30) print(text) ``` Note that if you pass a supported model name to `LLM.load()`, as shown above, it will download the model from the HF hub if it doesn't exist locally yet (use `litgpt download list` on the command line to get a list of all currently supported models).
Alternatively, to load a model from a local path, just provide the corresponding path as input to the `load` method: ```python llm = LLM.load("path/to/my/local/checkpoint") ```   **More information and additional resources** - [tutorials/python-api](python-api.md): The LitGPT Python API documentation   ## Evaluating models LitGPT comes with a handy `litgpt evaluate` command to evaluate models with [Eleuther AI's Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). For example, to evaluate the previously downloaded `microsoft/phi-2` model on several tasks available from the Evaluation Harness, you can use the following command: ```bash litgpt evaluate microsoft/phi-2 --batch_size 16 \ --tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" ``` (A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).)   ## Deploy LLMs You can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT's built-in serving capabilities: ```bash # 1) Download a pretrained model (alternatively, use your own finetuned model) litgpt download microsoft/phi-2 # 2) Start the server litgpt serve microsoft/phi-2 ``` ```python # 3) Use the server (in a separate session) import requests, json response = requests.post( "http://127.0.0.1:8000/predict", json={"prompt": "Fix typos in the following sentence: Example input"} ) print(response.json()["output"]) ``` This prints: ``` Instruct: Fix typos in the following sentence: Example input Output: Example input. ```   **More information and additional resources** - [tutorials/deploy](deploy.md): A full deployment tutorial and example   ## Converting LitGPT model weights to `safetensors` format Sometimes, it can be useful to convert LitGPT model weights for third-party and external tools.
For example, we can convert a LitGPT model to the Hugging Face format and save it via `.safetensors` files, which we can do as follows: ```bash litgpt convert_from_litgpt microsoft/phi-2 out/converted_model/ ``` Certain tools like the `.from_pretrained` method in Hugging Face `transformers` also require the original `config.json` file that originally came with the downloaded model: ```bash cp checkpoints/microsoft/phi-2/config.json out/converted_model/config.json ``` You can now load the model into a Hugging Face transformers model and save it in a `.safetensors` format as follows: ```python import torch from transformers import AutoModel # Load model state_dict = torch.load('out/converted_model/model.pth') model = AutoModel.from_pretrained( "microsoft/phi-2", state_dict=state_dict ) # Save .safetensors files model.save_pretrained("out/converted_model/") ``` ``` ⚡ ~/litgpt ls -lh out/converted_model total 16G -rwxr--r-- 1 sebastian sebastian 891 Mar 20 17:08 config.json -rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:08 model-00001-of-00003.safetensors -rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:09 model-00002-of-00003.safetensors -rw-r--r-- 1 sebastian sebastian 601M Mar 20 17:09 model-00003-of-00003.safetensors -rw-r--r-- 1 sebastian sebastian 5.2G Mar 20 16:30 model.pth -rw-r--r-- 1 sebastian sebastian 33K Mar 20 17:09 model.safetensors.index.json ``` You can then use the model with external tools, for example, Eleuther AI's [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) (see the `lm_eval` installation instructions [here](https://github.com/EleutherAI/lm-evaluation-harness?tab=readme-ov-file#install)).
The LM Evaluation Harness requires a tokenizer to be present in the model checkpoint folder, which we can copy from the original downloaded checkpoint: ```bash # Copy the tokenizer needed by the Eval Harness cp checkpoints/microsoft/phi-2/tokenizer* out/converted_model ``` Then, we can run the Evaluation Harness as follows: ```bash lm_eval --model hf \ --model_args pretrained="out/converted_model" \ --tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \ --device "cuda:0" \ --batch_size 4 ```   > [!TIP] > The Evaluation Harness tasks above are those used in the Open LLM Leaderboard. You can find a list of all supported tasks [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).   **More information and additional resources** - [tutorials/convert_lit_models](./convert_lit_models.md): Tutorial on converting LitGPT weights   ## Get involved! We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the [GitHub Issue](https://github.com/Lightning-AI/litgpt/issues) tracker. We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.   > [!TIP] > Unsure about contributing? Check out our [How to Contribute to LitGPT](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) guide.   If you have general questions about building with LitGPT, please [join our Discord](https://discord.gg/VptPCZkGNa). ================================================ FILE: tutorials/convert_hf_checkpoint.md ================================================ # Converting Hugging Face Transformers to LitGPT weights By default, the `litgpt download` command converts the downloaded HF checkpoint files into a LitGPT compatible format after downloading.
For example, ```bash litgpt download EleutherAI/pythia-14m ``` creates the following files: ``` checkpoints/ └── EleutherAI/ └── pythia-14m/ ├── config.json ├── generation_config.json ├── model_config.yaml # LitGPT specific file ├── lit_model.pth # LitGPT specific file ├── pytorch_model.bin ├── tokenizer.json └── tokenizer_config.json ``` To disable the automatic conversion, which is useful for development and debugging purposes, you can run the `litgpt download` command with the `--convert_checkpoint false` flag. This will only download the checkpoint files but will not convert them for use in LitGPT: ```bash rm -rf checkpoints/EleutherAI/pythia-14m litgpt download EleutherAI/pythia-14m \ --convert_checkpoint false ls checkpoints/EleutherAI/pythia-14m ``` ``` checkpoints/ └── EleutherAI/ └── pythia-14m/ ├── config.json ├── generation_config.json ├── pytorch_model.bin ├── tokenizer.json └── tokenizer_config.json ``` The required `model_config.yaml` and `lit_model.pth` files can then be manually generated via the `litgpt/scripts/convert_hf_checkpoint.py` script: ```bash litgpt convert_to_litgpt checkpoints/EleutherAI/pythia-14m ``` ================================================ FILE: tutorials/convert_lit_models.md ================================================ ## Converting LitGPT weights to Hugging Face Transformers LitGPT weights need to be converted to a format that Hugging Face understands with a [conversion script](../litgpt/scripts/convert_lit_checkpoint.py) before our scripts can run. We provide a helpful command to convert LitGPT models back to their equivalent Hugging Face Transformers format: ```bash litgpt convert_from_litgpt checkpoint_dir converted_dir ``` These paths are just placeholders; you will need to customize them based on which finetuning or pretraining command you ran and its configuration.
### Loading converted LitGPT checkpoints into transformers For example, ```bash cp checkpoints/repo_id/config.json converted/config.json ``` Then, you can load the checkpoint file in a Python session as follows: ```python import torch from transformers import AutoModel state_dict = torch.load("output_dir/model.pth") model = AutoModel.from_pretrained( "output_dir/", local_files_only=True, state_dict=state_dict ) ``` Alternatively, you can also load the model without copying the `config.json` file as follows: ```python model = AutoModel.from_pretrained("online_repo_id", state_dict=state_dict) ``` ### Merging LoRA weights Please note that if you want to convert a model that has been finetuned using an adapter like LoRA, these weights should be [merged](../litgpt/scripts/merge_lora.py) to the checkpoint prior to converting. ```sh litgpt merge_lora path/to/lora/checkpoint_dir ```

# A finetuning and conversion tutorial This section contains a reproducible example for finetuning a LitGPT model and converting it back into a Hugging Face `transformers` model. 1. Download a model of interest: For convenience, we first specify an environment variable (optional) to avoid copying and pasting the whole path: ```bash export repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T ``` Instead of using TinyLlama, you can replace the `repo_id` target with any other model repository specifier that is currently supported by LitGPT. You can get a list of supported repository specifiers by running `litgpt download list`. Then, we download the model we specified via `$repo_id` above: ```bash litgpt download $repo_id ``` 2. Finetune the model: ```bash export finetuned_dir=out/lit-finetuned-model litgpt finetune_lora $repo_id \ --out_dir $finetuned_dir \ --train.epochs 1 \ --data Alpaca ``` 3. Merge LoRA weights: Note that this step only applies if the model was finetuned with `litgpt finetune_lora` as above, and not when `litgpt finetune_full` was used for finetuning. ```bash litgpt merge_lora $finetuned_dir/final ``` 4. Convert the finetuned model back into a HF format: ```bash litgpt convert_from_litgpt $finetuned_dir/final/ out/hf-tinyllama/converted ``` 5. Load the model into a `transformers` model: ```python import torch from transformers import AutoModel state_dict = torch.load('out/hf-tinyllama/converted/model.pth') model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", state_dict=state_dict) ```   ## Using the LM Evaluation Harness To evaluate LitGPT models, use the integrated evaluation utilities based on Eleuther AI's LM Evaluation Harness. For more information, please see the [evaluation](evaluation.md) documentation.
Alternatively, if you wish to use converted LitGPT models with the LM Evaluation Harness from [Eleuther AI's GitHub repository](https://github.com/EleutherAI/lm-evaluation-harness), you can use the following steps. 1. Follow the instructions above to load the model into a Hugging Face transformers model. 2. Create the `.safetensors` model files: ```python model.save_pretrained("out/hf-tinyllama/converted/") ``` 3. Copy the tokenizer files into the model-containing directory: ```bash cp checkpoints/$repo_id/tokenizer* out/hf-tinyllama/converted ``` 4. Run the evaluation harness, for example: ```bash lm_eval --model hf \ --model_args pretrained=out/hf-tinyllama/converted \ --tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \ --device "cuda:0" \ --batch_size 4 ``` ================================================ FILE: tutorials/deploy.md ================================================ # Serve and Deploy LLMs This document shows how you can serve a LitGPT model for deployment.   ## Serve an LLM with LitServe This section illustrates how we can set up an inference server for a phi-2 LLM using `litgpt serve` that is minimal and highly scalable.   ### Step 1: Start the inference server ```bash # 1) Download a pretrained model (alternatively, use your own finetuned model) litgpt download microsoft/phi-2 # 2) Start the server litgpt serve microsoft/phi-2 ``` > [!TIP] > Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more.   ### Step 2: Query the inference server You can now send requests to the inference server you started in step 1. For example, in a new Python session, we can send requests to the inference server as follows: ```python import requests, json response = requests.post( "http://127.0.0.1:8000/predict", json={"prompt": "Fix typos in the following sentence: Example input"} ) print(response.json()["output"]) ``` Executing the code above prints the following output: ``` Example input.
```   ### Optional: Use the streaming mode The 2-step procedure described above returns the complete response all at once. If you want to stream the response on a token-by-token basis, start the server with the streaming option enabled: ```bash litgpt serve microsoft/phi-2 --stream true ``` Then, use the following updated code to query the inference server: ```python import requests, json response = requests.post( "http://127.0.0.1:8000/predict", json={"prompt": "Fix typos in the following sentence: Example input"}, stream=True ) # stream the response for line in response.iter_lines(decode_unicode=True): if line: print(json.loads(line)["output"], end="") ``` ``` Sure, here is the corrected sentence: Example input ```   ## Serve an LLM with OpenAI-compatible API LitGPT provides OpenAI-compatible endpoints that allow you to use the OpenAI SDK or any OpenAI-compatible client to interact with your models. This is useful for integrating LitGPT into existing applications that use the OpenAI API.   ### Step 1: Start the server with OpenAI specification ```bash # 1) Download a pretrained model (alternatively, use your own finetuned model) litgpt download HuggingFaceTB/SmolLM2-135M-Instruct # 2) Start the server with OpenAI-compatible endpoints litgpt serve HuggingFaceTB/SmolLM2-135M-Instruct --openai_spec true ``` > [!TIP] > The `--openai_spec true` flag enables OpenAI-compatible endpoints at `/v1/chat/completions` instead of the default `/predict` endpoint.   ### Step 2: Query using OpenAI-compatible endpoints You can now send requests to the OpenAI-compatible endpoint using curl: ```bash curl -X POST http://127.0.0.1:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "SmolLM2-135M-Instruct", "messages": [{"role": "user", "content": "Hello! 
How are you?"}] }' ``` Or use the OpenAI Python SDK: ```python from openai import OpenAI # Configure the client to use your local LitGPT server client = OpenAI( base_url="http://127.0.0.1:8000/v1", api_key="not-needed" # LitGPT doesn't require authentication by default ) response = client.chat.completions.create( model="SmolLM2-135M-Instruct", messages=[ {"role": "user", "content": "Hello! How are you?"} ] ) print(response.choices[0].message.content) ```   ## Serve an LLM UI with Chainlit If you are interested in developing a simple ChatGPT-like UI prototype, see the Chainlit tutorial in the following Studio: Open In Studio ================================================ FILE: tutorials/developer-docs/README.md ================================================ LitGPT developer documentation files. ================================================ FILE: tutorials/developer-docs/adding-models.md ================================================ # Adding New Models This document provides an overview and explanation of how new LLM architectures and model weights can be added to LitGPT.   > [!NOTE] > One of the design focus areas of LitGPT is to provide efficient readable code. At the same time, LitGPT aims to support selected LLMs that are useful to the community. LitGPT aims to reuse and share as much code as possible between different LLMs to strike a balance between code readability and enabling support for various LLMs. In short, we try to minimize writing custom code for a given LLM and aim for code reuse.     ## 1. Discuss the LLM to be added As an open-source project, we appreciate your contributions! However, before you begin putting valuable time and work into a contribution, ideally, open an issue to discuss whether support for a certain model is within the project's scope.   ## 2. 
Set up your development environment Clone the repository: ```bash git clone https://github.com/Lightning-AI/litgpt.git ``` Then, install it in "editable" mode for development: ```bash cd litgpt pip install -e ".[all]" ```   ## 3. Update the config file Update the [litgpt/config.py](../../litgpt/config.py) config file, adding the new model configuration there. It's easiest to start with the most similar model, copy the configuration, and then modify it according to the `config.json` file on the HF hub. For example, suppose an entry for Llama 3 8B already exists and you want to add support for Llama 3 70B. Copy the Llama 3 8B entry: ```python # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json dict( name="Llama-3-8B{}", hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"), vocab_size=128256, padding_multiple=64, n_layer=32, n_head=32, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=14336, rope_base=500000, ), ``` Then create the entry for the 70B model. Here, make sure you update the values according to the `config.json` file available on the HF hub: ```python # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json dict( name="Llama-3-70B{}", hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"), vocab_size=128256, padding_multiple=64, n_layer=80, n_head=64, n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False, norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP", intermediate_size=28672, rope_base=500000, ), ```   > [!NOTE] > Some models may require you to implement a new MLP class analogous to `class LLaMAMLP`. > A more or less reliable indicator is the presence of a `modeling.py` file in the model's original repository. > If this file exists, it suggests that this model requires custom code.
> This will then also require additional changes beyond simply updating > the configuration in LitGPT's `config.py`.   ## 4. Try downloading the model After making the modifications above, try downloading the model: ```bash litgpt download meta-llama/Meta-Llama-3-70B --access_token ... ```   > [!NOTE] > Not all models require an access token   If the conversion following the download fails, proceed with the next section.   ## 5. Update the checkpoint conversion script If the `litgpt download ...` command from the previous section failed, you may have to adjust the checkpoint conversion script: [litgpt/scripts/convert_hf_checkpoint.py](../../litgpt/scripts/convert_hf_checkpoint.py). Here, you may have to adjust or implement a new `def copy_weights_hf_...` function. You can test the updated conversion code without needing to redownload the weights as follows: ```bash python litgpt/scripts/convert_hf_checkpoint.py meta-llama/Meta-Llama-3-70B ```   ## 6. Add the Prompt Style If you are adding a new model class, find out its prompt style. First, check [litgpt/prompts.py](../../litgpt/prompts.py) if a similar prompt style template already exists. For Llama 3, this is as follows: ```python class Llama3(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: # https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229 return ( "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" "You are a helpful assistant.<|eot_id|>\n" # The system prompt is optional "<|start_header_id|>user<|end_header_id|>\n\n" f"{prompt}<|eot_id|>\n" "<|start_header_id|>assistant<|end_header_id|>\n\n" ) def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: return ( [tokenizer.eos_id], [tokenizer.token_to_id("<|eot_id|>")], ) ``` If your model requires a different prompt template, create a new `PromptStyle` class. 
Then, in the same file, update the `prompt_styles` dictionary: ```python prompt_styles: Dict[str, Type[PromptStyle]] = { ... "llama3": Llama3, } ``` Finally, also in the same file, update the `model_name_to_prompt_style` function: ```python def model_name_to_prompt_style(model_name: str) -> PromptStyle: ... if re.search("Llama-3.*-Instruct", model_name): return Llama3() ```   ## 7. Try using the model for inference Next, use the model to see if inference works: ```bash litgpt generate meta-llama/Meta-Llama-3-70B ```   > [!NOTE] > If you notice that the model produces nonsensical outputs, double-check the config file for incorrect values or other problems. The next section on adding unit tests may offer additional pointers.     ## 8. Add unit tests   ### 8.1 Add model unit tests Open the [`tests/test_model.py`](../../tests/test_model.py) file and add a new `def test_against_hf_...` function using one of the existing functions as a template. For instance, ```python def test_against_hf_llama2(ours_kwargs, device, dtype): ... # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) ``` If the ```bash litgpt generate meta-llama/Meta-Llama-3-70B ``` command from the previous section produces incoherent text, this function can be a helpful guide for debugging. For this, modify the implementation in the `transformers` and `litgpt` packages (on your local installation) to inspect or print out the intermediate values at a layer. It's recommended to start with the embedding layers and then go through one layer at a time to find out where the values differ, which gives you pointers for debugging. Run the unit test via ```bash pytest tests/test_model.py::test_against_hf_...
```   ### 8.2 Add prompt style unit test Open the [`tests/test_prompts.py`](../../tests/test_prompts.py) file and add a test for the respective prompts you added earlier, if applicable. For example, ```python def test_prompt_style_from_config(): model_names = [ ... "Llama-3-70B-Instruct", ... ] ``` Run the unit test via ```bash pytest tests/test_prompts.py ```   ## 9. Try finetuning the model Now, try finetuning the model: ```bash litgpt finetune meta-llama/Meta-Llama-3-70B --train.max_steps 10 ```   ## 10. Update the documentation Finally, update the documentation files.   ### 10.1 Update the README file Update the "All Models" table in the [README.md](../../README.md) file.   ### 10.2 Update the download tutorials Add the new model to the model table at the top as well as to the list under `litgpt download list`. ================================================ FILE: tutorials/developer-docs/python-api.md ================================================ # LitGPT High-level Python API This is a work-in-progress draft for a high-level LitGPT Python API.   ## Model loading & saving The `LLM.load` command loads an `llm` object, which contains both the model object (a PyTorch module) and a preprocessor. ```python from litgpt import LLM llm = LLM.load( model="url | local_path", # high-level users only need to care about those: memory_reduction="none | medium | strong", # advanced options for technical users: source="hf | local | other", quantize="bnb.nf4", precision="bf16-true", device="auto | cuda | cpu", ) ``` Here, - `llm.model` contains the PyTorch Module - and `llm.preprocessor.tokenizer` contains the tokenizer The `llm.save` command saves the model weights, tokenizer, and configuration information. ```python llm.save(checkpoint_dir, format="lightning | ollama | hf") ```   ## Inference / Chat ``` response = llm.generate( prompt="What do Llamas eat?", temperature=0.1, top_p=0.8, ... ) ```   ## Dataset The `llm.prepare_dataset` command prepares a dataset for training.
``` llm.download_dataset( URL, ... ) ``` ``` dataset = llm.prepare_dataset( path, task="pretrain | instruction_finetune", test_portion=0.1, ... ) ```   ## Training ```python llm.instruction_finetune( config=None, dataset=dataset, max_iter=10, method="full | lora | adapter | adapter_v2" ) ``` ```python llm.pretrain(config=None, dataset=dataset, max_iter=10, ...) ```   ## Serving ```python llm.serve(port=8000) ``` Then in another Python session: ```python import requests, json response = requests.post( "http://127.0.0.1:8000/predict", json={"prompt": "Fix typos in the following sentence: Example input"} ) print(response.json()["output"]) ``` ================================================ FILE: tutorials/download_model_weights.md ================================================ # Download Model Weights with LitGPT LitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the `litgpt download list` command.   | Model | Model size | Author | Reference | |----|----|----|----| | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) | | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | | Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 
2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) | | Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786) | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) | | Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) | | Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) | | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 
2024](https://aclanthology.org/2024.acl-long.841/) | | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | | Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) | | Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) | | Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743) | | Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233) | | Phi 4 Reasoning | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) | | Phi 4 Reasoning Plus | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) | | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/) | | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al.
2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) | | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | | Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |   ## General Instructions ### 1.
List Available Models To see all supported models, run the following command: ```bash litgpt download list ``` The output is shown below: ``` allenai/OLMo-1B-hf allenai/OLMo-7B-hf allenai/OLMo-7B-Instruct-hf bsc-lt/salamandra-2b bsc-lt/salamandra-2b-instruct bsc-lt/salamandra-7b bsc-lt/salamandra-7b-instruct codellama/CodeLlama-13b-hf codellama/CodeLlama-13b-Instruct-hf codellama/CodeLlama-13b-Python-hf codellama/CodeLlama-34b-hf codellama/CodeLlama-34b-Instruct-hf codellama/CodeLlama-34b-Python-hf codellama/CodeLlama-70b-hf codellama/CodeLlama-70b-Instruct-hf codellama/CodeLlama-70b-Python-hf codellama/CodeLlama-7b-hf codellama/CodeLlama-7b-Instruct-hf codellama/CodeLlama-7b-Python-hf databricks/dolly-v2-12b databricks/dolly-v2-3b databricks/dolly-v2-7b deepseek-ai/DeepSeek-R1-Distill-Llama-8B deepseek-ai/DeepSeek-R1-Distill-Llama-70B EleutherAI/pythia-1.4b EleutherAI/pythia-1.4b-deduped EleutherAI/pythia-12b EleutherAI/pythia-12b-deduped EleutherAI/pythia-14m EleutherAI/pythia-160m EleutherAI/pythia-160m-deduped EleutherAI/pythia-1b EleutherAI/pythia-1b-deduped EleutherAI/pythia-2.8b EleutherAI/pythia-2.8b-deduped EleutherAI/pythia-31m EleutherAI/pythia-410m EleutherAI/pythia-410m-deduped EleutherAI/pythia-6.9b EleutherAI/pythia-6.9b-deduped EleutherAI/pythia-70m EleutherAI/pythia-70m-deduped garage-bAInd/Camel-Platypus2-13B garage-bAInd/Camel-Platypus2-70B garage-bAInd/Platypus-30B garage-bAInd/Platypus2-13B garage-bAInd/Platypus2-70B garage-bAInd/Platypus2-70B-instruct garage-bAInd/Platypus2-7B garage-bAInd/Stable-Platypus2-13B google/codegemma-7b-it google/gemma-3-27b-it google/gemma-3-12b-it google/gemma-3-4b-it google/gemma-3-1b-it google/gemma-2-27b google/gemma-2-27b-it google/gemma-2-2b google/gemma-2-2b-it google/gemma-2-9b google/gemma-2-9b-it google/gemma-2b google/gemma-2b-it google/gemma-7b google/gemma-7b-it h2oai/h2o-danube2-1.8b-chat HuggingFaceTB/SmolLM2-135M HuggingFaceTB/SmolLM2-135M-Instruct HuggingFaceTB/SmolLM2-360M 
HuggingFaceTB/SmolLM2-360M-Instruct HuggingFaceTB/SmolLM2-1.7B HuggingFaceTB/SmolLM2-1.7B-Instruct lmsys/longchat-13b-16k lmsys/longchat-7b-16k lmsys/vicuna-13b-v1.3 lmsys/vicuna-13b-v1.5 lmsys/vicuna-13b-v1.5-16k lmsys/vicuna-33b-v1.3 lmsys/vicuna-7b-v1.3 lmsys/vicuna-7b-v1.5 lmsys/vicuna-7b-v1.5-16k meta-llama/Llama-2-13b-chat-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-chat-hf meta-llama/Llama-2-70b-hf meta-llama/Llama-2-7b-chat-hf meta-llama/Llama-2-7b-hf meta-llama/Llama-3.2-1B meta-llama/Llama-3.2-1B-Instruct meta-llama/Llama-3.2-3B meta-llama/Llama-3.2-3B-Instruct meta-llama/Llama-3.3-70B-Instruct meta-llama/Meta-Llama-3-70B meta-llama/Meta-Llama-3-70B-Instruct meta-llama/Meta-Llama-3-8B meta-llama/Meta-Llama-3-8B-Instruct meta-llama/Meta-Llama-3.1-405B meta-llama/Meta-Llama-3.1-405B-Instruct meta-llama/Meta-Llama-3.1-70B meta-llama/Meta-Llama-3.1-70B-Instruct meta-llama/Meta-Llama-3.1-8B meta-llama/Meta-Llama-3.1-8B-Instruct microsoft/phi-1_5 microsoft/phi-2 microsoft/Phi-3-mini-128k-instruct microsoft/Phi-3-mini-4k-instruct microsoft/Phi-3.5-mini-instruct microsoft/phi-4 microsoft/Phi-4-mini-instruct mistralai/mathstral-7B-v0.1 mistralai/Mistral-7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2 mistralai/Mistral-7B-Instruct-v0.3 mistralai/Mistral-7B-v0.1 mistralai/Mistral-7B-v0.3 mistralai/Mistral-Large-Instruct-2407 mistralai/Mistral-Large-Instruct-2411 mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mixtral-8x7B-v0.1 mistralai/Mixtral-8x22B-Instruct-v0.1 mistralai/Mixtral-8x22B-v0.1 NousResearch/Nous-Hermes-13b NousResearch/Nous-Hermes-llama-2-7b NousResearch/Nous-Hermes-Llama2-13b nvidia/Llama-3.1-Nemotron-70B-Instruct-HF openlm-research/open_llama_13b openlm-research/open_llama_3b openlm-research/open_llama_7b Qwen/Qwen2.5-0.5B Qwen/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-1.5B Qwen/Qwen2.5-1.5B-Instruct Qwen/Qwen2.5-3B Qwen/Qwen2.5-3B-Instruct Qwen/Qwen2.5-7B Qwen/Qwen2.5-7B-Instruct Qwen/Qwen2.5-7B-Instruct-1M Qwen/Qwen2.5-14B 
Qwen/Qwen2.5-14B-Instruct Qwen/Qwen2.5-14B-Instruct-1M Qwen/Qwen2.5-32B Qwen/Qwen2.5-32B-Instruct Qwen/Qwen2.5-72B Qwen/Qwen2.5-72B-Instruct Qwen/Qwen2.5-Coder-0.5B Qwen/Qwen2.5-Coder-0.5B-Instruct Qwen/Qwen2.5-Coder-1.5B Qwen/Qwen2.5-Coder-1.5B-Instruct Qwen/Qwen2.5-Coder-3B Qwen/Qwen2.5-Coder-3B-Instruct Qwen/Qwen2.5-Coder-7B Qwen/Qwen2.5-Coder-7B-Instruct Qwen/Qwen2.5-Coder-14B Qwen/Qwen2.5-Coder-14B-Instruct Qwen/Qwen2.5-Coder-32B Qwen/Qwen2.5-Coder-32B-Instruct Qwen/Qwen2.5-Math-1.5B Qwen/Qwen2.5-Math-1.5B-Instruct Qwen/Qwen2.5-Math-7B Qwen/Qwen2.5-Math-7B-Instruct Qwen/Qwen2.5-Math-72B Qwen/Qwen2.5-Math-72B-Instruct Qwen/Qwen3-0.6B Qwen/Qwen3-0.6B-Base Qwen/Qwen3-1.7B Qwen/Qwen3-1.7B-Base Qwen/Qwen3-4B Qwen/Qwen3-4B-Base Qwen/Qwen3-8B Qwen/Qwen3-8B-Base Qwen/Qwen3-14B Qwen/Qwen3-14B-Base Qwen/Qwen3-32B Qwen/Qwen3-30B-A3B Qwen/Qwen3-30B-A3B-Base Qwen/Qwen3-235B-A22B Qwen/Qwen3-4B-Thinking-2507 Qwen/Qwen3-4B-Instruct-2507 Qwen/Qwen3-30B-A3B-Thinking-2507 Qwen/Qwen3-30B-A3B-Instruct-2507 Qwen/Qwen3-235B-A22B-Thinking-2507 Qwen/Qwen3-235B-A22B-Instruct-2507 Qwen/QwQ-32B Qwen/QwQ-32B-Preview stabilityai/FreeWilly2 stabilityai/stable-code-3b stabilityai/stablecode-completion-alpha-3b stabilityai/stablecode-completion-alpha-3b-4k stabilityai/stablecode-instruct-alpha-3b stabilityai/stablelm-3b-4e1t stabilityai/stablelm-base-alpha-3b stabilityai/stablelm-base-alpha-7b stabilityai/stablelm-tuned-alpha-3b stabilityai/stablelm-tuned-alpha-7b stabilityai/stablelm-zephyr-3b tiiuae/falcon-180B tiiuae/falcon-180B-chat tiiuae/falcon-40b tiiuae/falcon-40b-instruct tiiuae/falcon-7b tiiuae/falcon-7b-instruct tiiuae/Falcon3-1B-Base tiiuae/Falcon3-1B-Instruct tiiuae/Falcon3-3B-Base tiiuae/Falcon3-3B-Instruct tiiuae/Falcon3-7B-Base tiiuae/Falcon3-7B-Instruct tiiuae/Falcon3-10B-Base tiiuae/Falcon3-10B-Instruct TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T togethercomputer/LLaMA-2-7B-32K togethercomputer/RedPajama-INCITE-7B-Base 
togethercomputer/RedPajama-INCITE-7B-Chat togethercomputer/RedPajama-INCITE-7B-Instruct togethercomputer/RedPajama-INCITE-Base-3B-v1 togethercomputer/RedPajama-INCITE-Base-7B-v0.1 togethercomputer/RedPajama-INCITE-Chat-3B-v1 togethercomputer/RedPajama-INCITE-Chat-7B-v0.1 togethercomputer/RedPajama-INCITE-Instruct-3B-v1 togethercomputer/RedPajama-INCITE-Instruct-7B-v0.1 Trelis/Llama-2-7b-chat-hf-function-calling-v2 unsloth/Mistral-7B-v0.2 ```   > [!TIP] > To sort the list above by model name after the `/`, use `litgpt download list | sort -f -t'/' -k2`.   > [!NOTE] > If you want to adopt a model variant that is not listed in the table above but has a similar architecture to one of the supported models, you can use this model by using the `--model_name` argument as shown below: > > ```bash > litgpt download NousResearch/Hermes-2-Pro-Mistral-7B \ > --model_name Mistral-7B-v0.1 > ```   ### 2. Download Model Weights To download the weights for a specific model, provide a `<repo_id>` with the model's repository ID. For example: ```bash litgpt download <repo_id> ``` This command downloads the model checkpoint into the `checkpoints/<repo_id>` directory.   ### 3. Additional Help For more options, add the `--help` flag when running the script: ```bash litgpt download --help ```   ### 4. Run the Model After conversion, run the model with the given checkpoint path as input, adjusting `repo_id` accordingly: ```bash litgpt chat <repo_id> ```   ## Tinyllama Example This section shows a typical end-to-end example for downloading and using TinyLlama: 1. List available TinyLlama checkpoints: ```bash litgpt download list | grep Tiny ``` ``` TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T TinyLlama/TinyLlama-1.1B-Chat-v1.0 ``` 2. Download a TinyLlama checkpoint: ```bash export repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T litgpt download $repo_id ``` 3.
Use the TinyLlama model: ```bash litgpt chat $repo_id ```   ## Specific models and access tokens Note that certain models require that you've been granted access to the weights on the Hugging Face Hub. For example, to get access to the Gemma 2B model, follow the steps at <https://huggingface.co/google/gemma-2b>. After access is granted, you can find your HF hub token at <https://huggingface.co/settings/tokens>. Once you've been granted access and have obtained the access token, pass it via the additional `--access_token` argument: ```bash litgpt download google/gemma-2b \ --access_token your_hf_token ```   ## Finetunes and Other Model Variants Sometimes you want to download the weights of a finetune of one of the models listed above. To do this, you need to manually specify the `model_name` associated with the config to use. For example: ```bash litgpt download NousResearch/Hermes-2-Pro-Mistral-7B \ --model_name Mistral-7B-v0.1 ```   ## Tips for GPU Memory Limitations The `litgpt download` command will automatically convert the downloaded model checkpoint into a LitGPT-compatible format. In case this conversion fails due to GPU memory constraints, you can try to reduce the memory requirements by passing the `--dtype bf16-true` flag to convert all parameters into this smaller precision (however, note that most model weights are already in a bfloat16 format, so it may not have any effect): ```bash litgpt download <repo_id> --dtype bf16-true ``` (If your GPU does not support the bfloat16 format, you can also try a regular 16-bit float format via `--dtype 16-true`.)   ## Converting Checkpoints Manually For development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps.
You can do this by passing the `--convert_checkpoint false` option to the download script: ```bash litgpt download <repo_id> \ --convert_checkpoint false ``` and then calling the `convert_to_litgpt` command: ```bash litgpt convert_to_litgpt checkpoints/<repo_id> ```   ## Downloading Tokenizers Only In some cases we don't need the model weights, for example, when we are pretraining a model from scratch instead of finetuning it. For cases like this, you can use the `--tokenizer_only` flag to only download a model's tokenizer, which can then be used in the pretraining scripts: ```bash litgpt download TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \ --tokenizer_only true ``` and ```bash litgpt pretrain tiny-llama-1.1b \ --data ... \ --tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/ ``` ================================================ FILE: tutorials/evaluation.md ================================================ # LLM Evaluation   ## Using lm-evaluation-harness You can evaluate LitGPT using [EleutherAI's lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) framework with a large number of different evaluation tasks. You need to install the `lm-eval` framework first: ```bash pip install lm_eval ```   ### Evaluating LitGPT base models Suppose we downloaded a base model that we want to evaluate. Here, we use the `microsoft/phi-2` model: ```bash litgpt download microsoft/phi-2 ``` The download command above will save the model to the `checkpoints/microsoft/phi-2` directory, which we can specify in the following evaluation command: ``` litgpt evaluate microsoft/phi-2/ \ --batch_size 4 \ --tasks "hellaswag,truthfulqa_mc2,mmlu" \ --out_dir evaluate_model/ ``` The resulting output is as follows: ``` ... |---------------------------------------|-------|------|-----:|--------|-----:|---|-----:| ...
|truthfulqa_mc2 | 2|none | 0|acc |0.4656|± |0.0164| |hellaswag | 1|none | 0|acc |0.2569|± |0.0044| | | |none | 0|acc_norm|0.2632|± |0.0044| | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.2434|± |0.0036| | - humanities |N/A |none | 0|acc |0.2578|± |0.0064| | - other |N/A |none | 0|acc |0.2401|± |0.0077| | - social_sciences|N/A |none | 0|acc |0.2301|± |0.0076| | - stem |N/A |none | 0|acc |0.2382|± |0.0076| ``` Please note that the `litgpt evaluate` command runs an internal model conversion. This is only necessary the first time you want to evaluate a model, and it will skip the conversion steps if you run `litgpt evaluate` on the same checkpoint directory again. In some cases, for example, if you modified the model in the `checkpoint_dir` since the first `litgpt evaluate` call, you need to use the `--force_conversion` flag to update the files used by `litgpt evaluate` accordingly: ``` litgpt evaluate microsoft/phi-2/ \ --batch_size 4 \ --out_dir evaluate_model/ \ --tasks "hellaswag,truthfulqa_mc2,mmlu" \ --force_conversion true ```   > [!TIP] > Run `litgpt evaluate list` to print a list > of the supported tasks. To filter for a specific subset of tasks, e.g., MMLU, use `litgpt evaluate list | grep mmlu`. > [!TIP] > The evaluation may take a long time, and for testing purposes, you may want to reduce the number of tasks > or set a limit for the number of examples per task, for example, `--limit 10`.  
### Evaluating LoRA-finetuned LLMs No further conversion is necessary when evaluating LoRA-finetuned models, as the `finetune_lora` command already prepares the necessary merged model files: ```bash litgpt finetune_lora microsoft/phi-2 \ --out_dir lora_model ```   ```bash litgpt evaluate lora_model/final \ --batch_size 4 \ --tasks "hellaswag,truthfulqa_mc2,mmlu" \ --out_dir evaluate_model/ ```   ### Evaluating on a custom test set There is currently no built-in function to evaluate models on custom test sets. However, this section describes a general approach that users can take to evaluate the responses of a model using another LLM. Suppose you have a test dataset with the following structure: ```python test_data = [ { "instruction": "Name the author of 'Pride and Prejudice'.", "input": "", "output": "Jane Austen." }, { "instruction": "Pick out the adjective from the following list.", "input": "run, tall, quickly", "output": "The correct adjective from the list is 'tall.'" }, ] ``` For simplicity, the list above only contains two entries. In practice, it is recommended to use test datasets that contain at least 100 entries (ideally 1000 or more). If your dataset is stored in JSON format, use the following code to load it: ```python import json with open("test_data.json", "r") as file: test_data = json.load(file) ``` Next, it is recommended to format the dataset according to a prompt style. For example, to use the `Alpaca` prompt style, use the following code: ```python from litgpt.prompts import Alpaca prompt_style = Alpaca() prompt_style.apply(prompt=test_data[0]["instruction"], **test_data[0]) ``` which returns ``` "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nName the author of 'Pride and Prejudice'.\n\n### Response:\n" ``` Next, load the LLM you want to evaluate.
For this example, we use `phi-2`: ```python from litgpt import LLM llm = LLM.load("microsoft/phi-2") ``` Then, using the loaded model, we add the test set responses to the dataset: ```python from tqdm import trange for i in trange(len(test_data)): response = llm.generate(prompt_style.apply(prompt=test_data[i]["instruction"], **test_data[i])) test_data[i]["response"] = response ``` Next, we use a second LLM to calculate the response quality on a scale from 0 to 100. It is recommended to use the 70B Llama 3 instruction-finetuned model for this task, or the smaller 8B Llama 3 model, which is more resource-efficient: ```python del llm # delete previous `llm` to free up GPU memory scorer = LLM.load("meta-llama/Meta-Llama-3-8B-Instruct", access_token="...") ``` Then, based on this LLM, we calculate the response quality with the following function, which reuses the `prompt_style` defined above to format each test entry: ```python from tqdm import tqdm def generate_model_scores(data_dict, model, response_field="response", target_field="output"): scores = [] for entry in tqdm(data_dict, desc="Scoring entries"): prompt = ( f"Given the input `{prompt_style.apply(prompt=entry['instruction'], **entry)}` " f"and correct output `{entry[target_field]}`, " f"score the model response `{entry[response_field]}`" f" on a scale from 0 to 100, where 100 is the best score. " f"Respond with the integer number only."
) score = model.generate(prompt, max_new_tokens=50) try: scores.append(int(score)) except ValueError: continue return scores ``` ```python scores = generate_model_scores(test_data, model=scorer) print(f"Number of scores: {len(scores)} of {len(test_data)}") print(f"Average score: {sum(scores)/len(scores):.2f}\n") ``` This will print out the average score on all test set entries: ``` Scoring entries: 100%|██████████| 2/2 [00:00<00:00, 4.37it/s] Number of scores: 2 of 2 Average score: 47.50 ``` ================================================ FILE: tutorials/examples/ptl-trainer/README.md ================================================ ## Minimal PyTorch Lightning Trainer Example The scripts in this folder provide minimal examples showing how to train a LitGPT model using LitGPT's `GPT` class with the [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) Trainer. You can run the scripts as follows:   ## Small 160M model: ```bash # Download the Pythia model litgpt download EleutherAI/pythia-160m python litgpt_ptl_small.py ```   ## Medium-sized 8B model: ```bash # Download the Llama 3.1 model litgpt download meta-llama/Meta-Llama-3.1-8B --access_token hf_...
python litgpt_ptl_medium.py ``` ================================================ FILE: tutorials/examples/ptl-trainer/litgpt_ptl_medium.py ================================================ import lightning as L import torch import litgpt from litgpt.data import Alpaca2k from litgpt.lora import GPT, merge_lora_weights class LitLLM(L.LightningModule): def __init__(self): super().__init__() self.model = GPT.from_name( name="Llama-3.1-8B", lora_r=32, lora_alpha=16, lora_dropout=0.05, lora_key=False, lora_value=True, ) litgpt.lora.mark_only_lora_as_trainable(self.model) def on_train_start(self): state_dict = torch.load("checkpoints/meta-llama/Meta-Llama-3.1-8B/lit_model.pth", mmap=True) self.model.load_state_dict(state_dict, strict=False) def training_step(self, batch): input_ids, targets = batch["input_ids"], batch["labels"] logits = self.model(input_ids) loss = litgpt.utils.chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:]) self.log("train_loss", loss, prog_bar=True) return loss def configure_optimizers(self): warmup_steps = 10 optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) return [optimizer], [scheduler] if __name__ == "__main__": data = Alpaca2k() tokenizer = litgpt.Tokenizer("checkpoints/meta-llama/Meta-Llama-3.1-8B") data.connect(tokenizer, batch_size=1, max_seq_length=512) trainer = L.Trainer( devices=1, max_epochs=2, accumulate_grad_batches=8, precision="bf16-true", ) with trainer.init_module(empty_init=True): model = LitLLM() trainer.fit(model, data) # Save final checkpoint merge_lora_weights(model.model) trainer.save_checkpoint("checkpoints/finetuned.ckpt", weights_only=True) ================================================ FILE: tutorials/examples/ptl-trainer/litgpt_ptl_small.py ================================================ # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. import lightning as L import torch from litgpt import LLM from litgpt.data import Alpaca2k class LitLLM(L.LightningModule): def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None): super().__init__() self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None) self.trainer_ckpt_path = trainer_ckpt_path def setup(self, stage): self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path) def training_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("validation_loss", loss, prog_bar=True) return loss def configure_optimizers(self): warmup_steps = 10 optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps) return [optimizer], [scheduler] if __name__ == "__main__": batch_size = 8 accumulate_grad_batches = 1 ######################################################### # Use case 1: Pretraining from random weights ######################################################### llm = LLM.load("EleutherAI/pythia-160m", tokenizer_dir="EleutherAI/pythia-160m", init="random") llm.save("pythia-160m-random-weights") del llm lit_model = LitLLM(checkpoint_dir="pythia-160m-random-weights", tokenizer_dir="EleutherAI/pythia-160m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") del lit_model 
############################################################################# # Use case 2: Continued pretraining / finetuning from downloaded checkpoint ############################################################################# lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") del lit_model ######################################################### # Use case 3: Resume training from Trainer checkpoint ######################################################### import os def find_latest_checkpoint(directory): latest_checkpoint = None latest_time = 0 for root, _, files in os.walk(directory): for file in files: if file.endswith(".ckpt"): file_path = os.path.join(root, file) file_time = os.path.getmtime(file_path) if file_time > latest_time: latest_time = file_time latest_checkpoint = file_path return latest_checkpoint lit_model = LitLLM( checkpoint_dir="EleutherAI/pythia-160m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs") ) data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ################################################################# # Use case 4: Resume training after saving a checkpoint manually ################################################################# lit_model.llm.save("finetuned_checkpoint") del lit_model lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint") 
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ================================================ FILE: tutorials/finetune.md ================================================ # Finetuning We provide simple finetuning commands (`litgpt finetune_*`) that instruction-finetune a pretrained model on datasets such as [Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [Dolly](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm), and others. For more information on the supported instruction datasets and how to prepare your own custom datasets, please see the [tutorials/prepare_dataset](prepare_dataset.md) tutorial. LitGPT currently supports the following finetuning methods: ```bash litgpt finetune_full litgpt finetune_lora litgpt finetune_adapter litgpt finetune_adapter_v2 ```   > [!TIP] > To install all required dependencies before finetuning, first run `pip install "litgpt[all]"`.   The following section provides more details about these methods, including links for additional resources.   ## LitGPT finetuning commands The section below provides additional information on the available methods and links to further resources.   ### Full finetuning ```bash litgpt finetune_full ``` This method trains all model weight parameters and is the most memory-intensive finetuning technique in LitGPT.
**More information and resources:** - the LitGPT [tutorials/finetune_full](finetune_full.md) tutorial   ### LoRA and QLoRA finetuning ```bash litgpt finetune_lora stabilityai/stablelm-base-alpha-3b ``` LoRA and QLoRA are parameter-efficient finetuning techniques that only require updating a small number of parameters, which makes them a more memory-efficient alternative to full finetuning. **More information and resources:** - the LitGPT [tutorials/finetune_lora](finetune_lora.md) tutorial - the LoRA paper ([Hu et al. 2021](https://arxiv.org/abs/2106.09685)) - the conceptual tutorial [Parameter-Efficient LLM Finetuning With Low-Rank Adaptation (LoRA)](https://lightning.ai/pages/community/tutorial/lora-llm/)   ### Adapter finetuning ```bash litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b ``` or ```bash litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b ``` Similar to LoRA, adapter finetuning is a parameter-efficient finetuning technique that only requires training a small subset of weight parameters, making this finetuning method more memory-efficient than full-parameter finetuning. **More information and resources:** - the LitGPT [tutorials/finetune_adapter](finetune_adapter.md) tutorial - the Llama-Adapter v2 ([Gao et al. 2023](https://arxiv.org/abs/2304.15010)) and Llama-Adapter ([Zhang et al.
2023](https://arxiv.org/abs/2303.16199)) papers that originally introduced these methods - the conceptual tutorial [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) ================================================ FILE: tutorials/finetune_adapter.md ================================================ # Finetuning with Adapter Adapter, first introduced for the LLaMA model as [LLaMA-Adapter](https://arxiv.org/abs/2303.16199), is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in an LLM. In total, there are only ~500k parameters to update during finetuning in StableLM 3B, which significantly reduces the memory footprint and speeds up training. We are able to demonstrate instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3060 GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour. If you are new to Adapter and are interested in learning more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful. LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Whereas the StableLM 3B model has ~500k trainable parameters with Adapter v1, Adapter v2 adds an additional ~1.5 M trainable parameters for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, Adapter v2 has ~2.3 M trainable parameters in total. ## Preparation The steps here only need to be done once: 1.
Follow the instructions in the [README](../README.md) to install the dependencies. 2. Download and convert the weights following our [guide](download_model_weights.md). LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more. You can optionally [prepare your own dataset](#tune-on-your-dataset). For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial. ## Running the finetuning ```bash litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \ --data Alpaca ``` or for Adapter V2 ```bash litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \ --data Alpaca ``` The finetuning requires at least one GPU with ~12 GB memory. You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available. Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently. To fit Adapter V2 into 12 GB of memory, set `--train.micro_batch_size 2`. For example, the following settings will let you finetune the model in under 1 hour: ```bash --devices 4 --train.micro_batch_size 4 ``` This script will save checkpoints periodically to the `out_dir` directory. If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name: ```bash litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \ --data Alpaca \ --out_dir out/adapter/my-model-finetuned ``` or for Adapter V2 ```bash litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \ --data Alpaca \ --out_dir out/adapter_v2/my-model-finetuned ``` If your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument.
For instance, to fine-tune on MPS (the GPU on modern Macs), you can run ```bash litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \ --data Alpaca \ --out_dir out/adapter/my-model-finetuned \ --precision 32-true ``` Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac. ### Quantization Optionally, finetuning using quantization can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type: ```bash litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \ --quantize "bnb.nf4" ``` or using `adapter_v2` with double-quantization: ```bash litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \ --quantize "bnb.nf4-dq" ``` For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md). ## Test the model You can test the finetuned model with your own instructions by running: ```bash litgpt generate_adapter stabilityai/stablelm-base-alpha-3b \ --prompt "Recommend a movie to watch on the weekend." ``` or for Adapter V2 ```bash litgpt generate_adapter_v2 stabilityai/stablelm-base-alpha-3b \ --prompt "Recommend a movie to watch on the weekend." ``` Output: ```text A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy... ``` If your GPU supports `bfloat16`, the script will automatically use it. ## Tune on your dataset You can easily train on your own instruction dataset saved in JSON format. 1. Create a JSON file in which each row holds one instruction-response pair. A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context. 
Below is an example json file: ```text [ { "instruction": "Arrange the given numbers in ascending order.", "input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template "output": "0, 2, 3, 4, 8" }, ... ] ``` 2. Run `litgpt finetune_adapter` or `litgpt finetune_adapter_v2` by passing in the location of your data (and optionally other parameters): ```bash litgpt finetune_adapter tiiuae/falcon-7b \ --data JSON \ --data.json_path data/mydata.json \ --out_dir data/mydata-finetuned ``` ================================================ FILE: tutorials/finetune_full.md ================================================ # Finetuning the whole model If you are interested in parameter-efficient finetuning, check out [finetune_adapter.md](finetune_adapter.md). In contrast to parameter-efficient finetuning, this "full" approach finetunes all model parameters, which is substantially more expensive. It may only be recommended as a baseline for comparison studies. ## Preparation The steps here only need to be done once: 1. Follow the instructions in the [README](../README.md) to install the dependencies. 2. Download and convert the weights following our [guide](download_model_weights.md). LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more. You can optionally [prepare your own dataset](#tune-on-your-dataset). For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial. ## Running the finetuning ```bash litgpt finetune_full tiiuae/falcon-7b \ --data Alpaca ``` Finetuning the falcon-7b model requires at least 8 GPUs with ~40 GB memory each. You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available. Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently. This script will save checkpoints periodically to the `out_dir` directory.
If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name: ```bash litgpt finetune_full tiiuae/falcon-7b \ --data Alpaca \ --out_dir out/full/my-model-finetuned ``` If your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument. For instance, to fine-tune on MPS (the GPU on modern Macs), you can run ```bash litgpt finetune_full tiiuae/falcon-7b \ --data Alpaca \ --out_dir out/full/my-model-finetuned \ --precision 32-true ``` Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac. ## Test the model You can test the finetuned model with your own instructions by passing the final checkpoint directory to `litgpt generate`: ```bash litgpt generate out/full/my-model-finetuned/final \ --prompt "Recommend a movie to watch on the weekend." ``` Output: ```text A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy... ``` If your GPU supports `bfloat16`, the script will automatically use it. ## Tune on your dataset You can easily train on your own instruction dataset saved in JSON format. 1. Create a JSON file in which each row holds one instruction-response pair. A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context. Below is an example json file: ```text [ { "instruction": "Arrange the given numbers in ascending order.", "input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template "output": "0, 2, 3, 4, 8" }, ... ] ``` 2.
Run `litgpt finetune_full` by passing in the location of your data (and optionally other parameters): ```bash litgpt finetune_full tiiuae/falcon-7b \ --data JSON \ --data.json_path data/mydata.json \ --out_dir data/mydata-finetuned ``` ================================================ FILE: tutorials/finetune_lora.md ================================================ # Finetuning with LoRA / QLoRA [Low-rank adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in an LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model. We demonstrate this method by instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU** with CUDA 11.8.   ## Preparation The steps here only need to be done once: 1. Follow the instructions in the [README](../README.md) to install the dependencies. 2. Download and convert the weights and save them in the `./checkpoints` folder. Weights can be downloaded following the instructions in the [download_model_weights](download_model_weights.md) documentation. LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more. You can optionally [prepare your own dataset](#tune-on-your-dataset). For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.   ## Running the Finetuning ```bash litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \ --data Alpaca ``` The finetuning requires at least one GPU with ~24 GB memory (RTX 3090). This script will save checkpoints periodically to the folder `out/`. > [!NOTE] > LoRA can be applied to not only `query`, `key` or `value` matrices, but also to `projection`, `mlp` and classification `head`.
> According to [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): "LoRA on all linear transformer block layers are required to match full finetuning performance". > By default LoRA is applied only to the `query` and `value` matrices. In order to apply LoRA to other weight matrices - change the arguments to `litgpt/finetune/lora.py` accordingly. Optionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type: ```bash litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \ --quantize "bnb.nf4" ``` and optionally with double-quantization: ```bash litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \ --quantize "bnb.nf4-dq" ``` The table below lists a comparison with different settings on a StableLM 3B model finetuned with LoRA on Alpaca for 1,000 iterations using a microbatch size of 1: | Settings | Training Memory | Training Time | Inference Memory | |---------------------------------------------|-----------------|---------------|-------------------| | Default (bf16-mixed) | 26.92 GB | 1.34 min | 21.43 GB | | --precision bf16-true | 9.69 GB | 1.24 min | 7.30 GB | | --precision bf16-true --quantize bnb.nf4 | 6.35 GB | 1.82 min | 3.20 GB | | --precision bf16-true --quantize bnb.nf4-dq | 6.19 GB | 1.87 min | 3.04 GB | The advantages of QLoRA-style quantization are more pronounced in larger models, such as Llama 2 7B. 
The table below summarizes the results for Llama 2 7B on Alpaca for 1,000 iterations using a microbatch size of 1: | Settings | Training Memory | Training Time | Inference Memory | |---------------------------------------------|------------------|---------------|------------------| | Default (bf16-mixed) | OutOfMemoryError | N/A | 40.21 GB | | --precision bf16-true | 21.30 GB | 2.36 min | 13.52 GB | | --precision bf16-true --quantize bnb.nf4 | 14.14 GB | 3.68 min | 4.57 GB | | --precision bf16-true --quantize bnb.nf4-dq | 13.84 GB | 3.83 min | 4.26 GB | For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).   ## Test the Model You can test the finetuned model with your own instructions by running: ```bash litgpt generate "out/lora/final" \ --prompt "Recommend a movie to watch on the weekend." ``` Output: ```text I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of... ``` If your GPU supports `bfloat16`, you can additionally pass `--precision "bf16-true"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision "32-true"`). In addition, you may use quantization methods, for example `--precision "bf16-true" --quantize "bnb.nf4"` brings the memory consumption further down to ~4.4 GB for StableLM-3B.   ## Tune on Your Dataset You can easily train on your own instruction dataset saved in JSON format. 1. Create a JSON file in which each row holds one instruction-response pair. A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.
Below is an example json file: ```text [ { "instruction": "Arrange the given numbers in ascending order.", "input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template "output": "0, 2, 3, 4, 8" }, ... ] ``` 2. Run `litgpt finetune_lora` by passing in the location of your data (and optionally other parameters): ```bash litgpt finetune_lora checkpoints/stabilityai/stablelm-base-alpha-3b \ --data JSON \ --data.json_path data/mydata.json \ --out_dir out_dir/mydata-finetuned ``` 3. Test and use the finetuned model: ```bash litgpt chat out_dir/mydata-finetuned/final ``` or ```bash litgpt serve out_dir/mydata-finetuned/final ```   ## Merging LoRA Weights (Optional) Finetuning a model with LoRA generates a `lit_model.pth.lora` file. This file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space. > [!NOTE] > LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`. > Manual merging is only necessary if you want to use the checkpoint outside LitGPT. If desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint. The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed. For example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows: ```bash litgpt merge_lora "out/lora/step-002000" ``` The command above creates a full `lit_model.pth` checkpoint file. ================================================ FILE: tutorials/full_finetune_example.py ================================================ """ This script is meant to be the simplest possible starting point for full finetuning a GPT model using lightning fabric with code (not CLI). 
- no checkpoints - no out dir - no precision - no resume - no train/eval args (or any args in general) - no logger (only to terminal) - no grad accumulation and no other fancy stuff. To add all the above stuff, you can slowly add them in yourself by looking at the code in litgpt/finetune/full.py or the docs for litgpt/fabric. """ import os import lightning as L import torch import torch.nn as nn from litgpt.data import Alpaca from litgpt.model import GPT, Config from litgpt.tokenizer import Tokenizer from litgpt.utils import num_parameters # training params/args SEED = 1337 MODEL_NAME = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # try also "stabilityai/stablelm-base-alpha-3b"! BATCH_SIZE = 4 LR_WARMUP_STEPS = 100 MAX_STEPS = 601 def validate(model, val_dataloader): model.eval() loss = 0 with torch.no_grad(): for batch in val_dataloader: input_ids, targets = batch["input_ids"], batch["labels"] logits = model(input_ids) logits = logits.reshape(-1, logits.size(-1)) targets = targets.reshape(-1) loss += nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:]) fabric.print(f"Validation loss: {loss / len(val_dataloader)}") def train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader): for iter_num, batch in enumerate(train_dataloader): input_ids, targets = batch["input_ids"], batch["labels"] # get model preds (logits) logits = model(input_ids) logits = logits.reshape(-1, logits.size(-1)) # get loss targets = targets.reshape(-1) loss = nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:]) # update weights fabric.backward(loss) optimizer.step() optimizer.zero_grad() scheduler.step() # print train loss every 100 steps if iter_num % 100 == 0 or iter_num == 0: fabric.print(f"Train iter {iter_num} - loss {loss}") # validate every 300 steps if iter_num % 300 == 0 or iter_num == 0: validate(model, val_dataloader) model.train() iter_num += 1 if iter_num >= MAX_STEPS: break def main(fabric): fabric.seed_everything(SEED) # setup 
data, make tokenizer and make dataloaders data = Alpaca() tokenizer = Tokenizer(checkpoint_dir=f"checkpoints/{MODEL_NAME}") data.connect(tokenizer=tokenizer, batch_size=BATCH_SIZE, max_seq_length=1024) data.setup() train_dataloader = data.train_dataloader() val_dataloader = data.val_dataloader() train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) # print how many steps in an epoch fabric.print(f"Steps in an epoch: {len(train_dataloader)}") # setup model config = Config.from_file(f"checkpoints/{MODEL_NAME}/model_config.yaml") model = GPT(config) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") model = fabric.setup(model) # setup optimizer optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02, betas=(0.9, 0.95)) optimizer = fabric.setup_optimizers(optimizer) # setup lr scheduler scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / LR_WARMUP_STEPS) scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(MAX_STEPS - LR_WARMUP_STEPS)) scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[LR_WARMUP_STEPS]) # Start training!!! train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader) if __name__ == "__main__": # check that the model exists (downloaded to ./checkpoints/) if not os.path.exists(f"checkpoints/{MODEL_NAME}"): print(f"Model {MODEL_NAME} not found. 
Please download it using `litgpt download --repo {MODEL_NAME}`") exit() ### Setup and launch fabric = L.Fabric(devices="auto", strategy="auto") fabric.launch(main) ================================================ FILE: tutorials/inference.md ================================================ # Inference We demonstrate how to run inference (next token prediction) with the GPT base model in the [`litgpt generate`](../litgpt/generate/base.py) command: ```bash litgpt generate stabilityai/stablelm-base-alpha-3b \ --prompt "Hello, my name is" ``` Output: ```text Hello, my name is Levi Durrer, I'm an Austrian journalist - Chairman of the Press Blair Party, with 37 years in the Press Blair International, and two years in the Spectre of Austerity for the other. I'm crossing my fingers that you will feel ``` The script assumes you have downloaded and converted the weights as described [here](download_model_weights.md). This will run the 3B pre-trained model and requires ~7 GB of GPU memory using the `bfloat16` datatype. ## Run interactively You can also chat with the model interactively: ```bash litgpt chat stabilityai/stablelm-tuned-alpha-3b ``` This script can work with any checkpoint. For the best chat-like experience, we recommend using it with a checkpoint fine-tuned for chatting, such as `stabilityai/stablelm-tuned-alpha-3b` or `togethercomputer/RedPajama-INCITE-Chat-3B-v1`. > [!TIP] > Use `--multiline true` to work with inputs that span multiple lines. ## Run a large model on one smaller device Check out our [quantization tutorial](quantize.md). ## Run a large model on multiple smaller devices We offer two scripts to leverage multiple devices for inference. ### [`litgpt generate_sequentially`](../litgpt/generate/sequentially.py) Allows you to run models that wouldn't fit in a single card by partitioning the transformer blocks across all your devices and running them sequentially.
For instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations. With 80 transformer layers, we could partition them across 8, 5, 4, or 2 devices. ```shell litgpt generate_sequentially meta-llama/Llama-2-70b-chat-hf \ --max_new_tokens 256 \ --num_samples 2 ``` Using A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable. | Devices | Max GPU RAM | Token/sec | |---------|-------------|-----------| | 2 | OOM | - | | 4 | 35.64 GB | 7.55 | | 5 | 28.72 GB | 7.49 | | 8 | 18.35 GB | 7.47 | Note that the memory usage will also depend on the `max_new_tokens` value used. The script also supports quantization. Using 4-bit precision, we can run the model on just 2 GPUs: ```shell litgpt generate_sequentially meta-llama/Llama-2-70b-chat-hf \ --max_new_tokens 256 \ --num_samples 2 \ --quantize bnb.nf4-dq ``` | Devices | Max GPU RAM | Token/sec | |---------|-------------|-----------| | 2 | 20.00 GB | 8.63 | | 4 | 10.80 GB | 8.23 | | 5 | 8.96 GB | 8.10 | | 8 | 6.23 GB | 8.18 | Smaller devices can also be used to run inference with this technique. ### [`litgpt generate_tp`](../litgpt/generate/tp.py) Uses tensor parallelism (TP) to run models that wouldn't fit on a single card by sharding the MLP and Attention QKV linear layers across all your devices. For instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations. The requirement is that both the intermediate size (for the MLP) and the QKV size (for attention) are divisible by the number of devices. With an intermediate size of 28672, we can use 2, 4, 7, or 8 devices. With a QKV size of 10240, we can use 2, 4, 5, or 8 devices. Since the script is configured to shard both, the intersection is used: we can only use 2, 4, or 8 devices.
```shell litgpt generate_tp meta-llama/Llama-2-70b-chat-hf \ --max_new_tokens 256 \ --num_samples 2 ``` Using A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable. | Devices | Max GPU RAM | Token/sec | |---------|-------------|-----------| | 2 | OOM | - | | 4 | 35.46 GB | 9.33 | | 8 | 18.19 GB | 8.61 | Note that the memory usage will also depend on the `max_new_tokens` value used. The script also supports quantization. Using 4-bit precision, we can run the model on just 2 GPUs: ```shell litgpt generate_tp meta-llama/Llama-2-70b-chat-hf \ --max_new_tokens 256 \ --num_samples 2 \ --quantize bnb.nf4-dq ``` | Devices | Max GPU RAM | Token/sec | |---------|-------------|-----------| | 2 | 19.79 GB | 6.72 | | 4 | 10.73 GB | 6.48 | | 8 | 6.15 GB | 6.20 | Smaller devices can also be used to run inference with this technique. ================================================ FILE: tutorials/mkdocs.yml ================================================ site_name: LitGPT Tutorials plugins: - pagetree theme: name: material ================================================ FILE: tutorials/oom.md ================================================ ## Dealing with out-of-memory (OOM) errors If you get this error while running a script ```bash OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacity of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF ``` it means that your GPU memory size wasn't big enough for the model and script configuration.
Here are a few things you can try: ### Reduce the micro batch size Adjust the `--train.micro_batch_size` argument in the fine-tuning and pretraining scripts. This variable determines the number of samples loaded per iteration. A smaller value will simply load fewer samples simultaneously. The minimum value is 1. Experiment with different micro batch sizes to find a balance between memory consumption and computational efficiency. Smaller micro batch sizes consume less memory but may result in slower training convergence. Conversely, larger micro batch sizes require more memory but can accelerate training speed. ### Reduce the model's context length The context length (`block_size` in the code) plays a significant role in running models with attention. * The pretraining scripts are configured to use the full context length of the model to train. * The finetuning scripts are configured to use the longest sample length of the training data to avoid allocating unnecessary memory (`--train.max_seq_length` argument). If that value is longer than the model's context length, an error is raised. An error is also raised if you try to run a batch that is longer than this limit. However, your hardware may not support such large context lengths. Here's what you can do: * For the pretraining scripts, you can simply reduce the `Config(block_size=...)` value. * For the finetuning scripts, you can trim the length of the samples in your dataset. All the finetuning scripts expose a `--train.max_seq_length=...` argument. This might also be useful in cases where sample lengths are highly unbalanced, as the presence of a single very long sample would incur a larger memory usage for all other shorter samples. For example, the median length of the samples in Alpaca is 110 tokens. Truncating the Alpaca dataset to 256 max tokens reduces the memory requirements of a Falcon 7B model from 23.52 GB to 15.73 GB.
For more information about the dataset truncation, please see the *Truncating datasets* section in the [prepare_dataset.md](prepare_dataset.md) tutorial. Keep in mind that reducing the context length will affect the modelling performance on text sequences longer than the limit. ### Use lower precision Our scripts expose the `--precision` argument, which directly impacts the memory usage. Using true lower precision (`16-true`, `bf16-true`) reduces the memory usage by half compared to `32-true`; however, the model might start producing NaNs due to the limited range of representable values. Mixed precision training (`16-mixed`, `bf16-mixed`) provides better stability but offers limited memory reduction. ### Do sharding across multiple GPUs For exceptionally large models, the aforementioned techniques might still not suffice. If you have multiple GPUs available, you can trade off memory for speed by changing the `--devices 1` argument in the scripts. Setting it to a value greater than 1 enables a parallelism technique (FSDP) that shards the memory across the different GPUs. The default configuration already uses activation checkpointing, but you can enable CPU offloading by changing the `cpu_offload=False` argument in the scripts. ### Try a different optimizer Our scripts use the [`AdamW` optimizer](https://pytorch.org/docs/main/generated/torch.optim.AdamW.html). It maintains 2 states for each trainable parameter of the model, meaning that the optimizer memory is double compared to an optimizer like [`SGD`](https://pytorch.org/docs/main/generated/torch.optim.SGD.html). You can try replacing it with your optimizer of choice that is lighter in memory requirements. Keep in mind that different optimizers have distinct optimization behaviors, so it's essential to assess their impact on the training process and model performance. Examples include the recently published [Sophia](https://arxiv.org/abs/2305.14342) and [Lion](https://arxiv.org/abs/2302.06675) optimizers.
This suggestion is particularly relevant for pretraining, as the trainable parameters in the model represent a small subset of the total in the fine-tuning scripts. ================================================ FILE: tutorials/prepare_dataset.md ================================================ # Preparing Datasets Below is a table of all datasets that are currently supported in LitGPT: | Name | Task | Size | Reference Repo | Paper / Blog | Data License | |--------------|-------------|---------------------|--------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Alpaca | Finetuning | 51,759 samples | [URL](https://github.com/tatsu-lab/stanford_alpaca) | [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html) | Attribution-NonCommercial 4.0 International, [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html) | | Alpaca-2k | Finetuning | 2000 samples | [URL](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) | See Alpaca above | See Alpaca Above | | Alpaca-GPT4 | Finetuning | 52,002 samples | [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | [URL](https://arxiv.org/abs/2304.03277) | Attribution-NonCommercial 4.0 International, [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/DATA_LICENSE) | | Alpaca Libre | Finetuning | 55,370 samples | [URL](https://github.com/mobarski/alpaca-libre) | - | CC0/MIT, [URL](https://github.com/mobarski/alpaca-libre) | | Deita | Finetuning | 9,500 samples | [URL](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft/tree/main/data) | [URL](https://arxiv.org/abs/2312.15685) | MIT 
[URL](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0/blob/main/README.md) | | Dolly | Finetuning | 15,011 samples | [URL](https://github.com/databrickslabs/dolly/tree/master/data) | [URL](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | CC-BY-SA, [URL](https://github.com/databrickslabs/dolly#model-overview) | | FLAN | Finetuning | 1,753,240 samples | [URL](https://huggingface.co/datasets/Muennighoff/flan) | [URL](https://blog.research.google/2023/02/the-flan-collection-advancing-open.html) | Subset dependent | | LongForm | Finetuning | 23,652 samples | [URL](https://github.com/akoksal/LongForm) | [URL](https://arxiv.org/abs/2304.08460) | No information provided and subset-dependent, [URL](https://github.com/akoksal/LongForm) | | LIMA | Finetuning | 1,084 samples | [URL](https://huggingface.co/datasets/GAIR/lima) | [URL](https://arxiv.org/abs/2305.11206) | "If the source data of LIMA has a stricter license than CC BY-NC-SA, the LIMA dataset follows the same. Otherwise, it follows the CC BY-NC-SA license", [URL](https://huggingface.co/datasets/GAIR/lima#license) | | OpenWeb Text | Pretraining | 8,013,769 documents | [URL](https://github.com/jcpeterson/openwebtext) | [URL](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) | Unspecified | | TinyLlama | Pretraining | 1 T tokens | [URL](https://github.com/jzhang38/TinyLlama) | [URL](https://arxiv.org/abs/2401.02385) | | | TinyStories | Pretraining | 4,967,871 stories | [URL](https://huggingface.co/datasets/roneneldan/TinyStories) | [URL](https://arxiv.org/abs/2305.07759) | CDLA-Sharing-1.0 |   ## Preparation The steps here only need to be done once before preparing the finetuning datasets in the following subsections: 1. Follow the instructions in the [README](../README.md) to install the dependencies. 2. Download and convert the weights following our [guide](download_model_weights.md).
For the following examples, we will focus on finetuning with the `litgpt finetune_lora` command and use a Falcon 7B model. However, the same steps apply to all other models and finetuning scripts. Please read the [tutorials/finetune_*.md](.) documents for more information about finetuning models.   > [!IMPORTANT] > By default, the maximum sequence length is obtained from the model configuration file. In case you run into out-of-memory errors, especially in the cases of LIMA and Dolly, > you can try to lower the context length by setting the `--train.max_seq_length` parameter, for example, `litgpt finetune_lora tiiuae/falcon-7b --train.max_seq_length 256`. For more information on truncating datasets, see the *Truncating datasets* subsection of the Alpaca section below.   ### Alpaca The Alpaca dataset consists of 52,000 instructions and demonstrations produced by OpenAI's text-davinci-003 engine. This data is used in instruction-tuning, helping improve the ability of language models to follow instructions. In its development, the creators leveraged the data generation methodology from the [Self-Instruct framework](https://github.com/yizhongw/self-instruct). The original [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) dataset can be used as follows: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Alpaca ```   > [!TIP] > Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.   #### Truncating datasets By default, the finetuning scripts will determine the size of the longest tokenized sample in the dataset to determine the block size. However, if you are willing to truncate a few examples in the training set, you can reduce the computational resource requirements significantly. For instance, you can set a sequence length threshold via `--train.max_seq_length`. We can determine an appropriate maximum sequence length by considering the distribution of the data sample lengths shown in the histogram below.
In this case, a cut-off of 256 may be a reasonable choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Alpaca \ --train.max_seq_length 256 ``` For comparison, the Falcon 7B model requires 23.52 GB of memory for the original Alpaca dataset and 15.73 GB of memory for the truncated Alpaca dataset when finetuning with LoRA using a micro batch size of 1 and bfloat16 precision.   ### Alpaca-2k [Alpaca-2k](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) is a smaller, 2000-sample subset of the Alpaca dataset described above. ```bash litgpt finetune_lora "tiiuae/falcon-7b" \ --data Alpaca2k ```   > [!TIP] > Use `litgpt finetune --data.help Alpaca2k` to list additional dataset-specific command line options.   The Alpaca-2k dataset distribution is shown below. ### Alpaca-GPT4 The Alpaca-GPT4 dataset was built by using the prompts of the original Alpaca dataset and generating the responses via GPT-4. The dataset consists of 52,000 instructions and responses. The original [Alpaca-GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset can be used as follows: ```bash litgpt finetune_lora "tiiuae/falcon-7b" \ --data AlpacaGPT4 ```   > [!TIP] > Use `litgpt finetune --data.help AlpacaGPT4` to list additional dataset-specific command line options.   The Alpaca-GPT4 dataset distribution is shown below.   ### Alpaca Libre [Alpaca Libre](https://github.com/mobarski/alpaca-libre) is a reimplementation of, or alternative to, Alpaca that uses the same formatting. To use Alpaca Libre instead of the original Alpaca dataset, use the following command: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Alpaca \ --data.file_url "https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json" \ --data.file_name "alpaca_libre_data_cleaned_archive.json" ```   > [!TIP] > Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.   The Alpaca Libre dataset distribution is shown below.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information). For this dataset, a cut-off of 256 may be a good choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Alpaca \ --data.file_url "https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json" \ --data.file_name "alpaca_libre_data_cleaned_archive.json" \ --train.max_seq_length 256 ```   ### Deita The Deita dataset (short for Data-Efficient Instruction Tuning for Alignment) is a collection of 9,500 prompts and responses, as described in the [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) paper. Using Falcon 7b as an example, we can use the dataset as follows: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Deita ```   > [!TIP] > Use `litgpt finetune --data.help Deita` to list additional dataset-specific command line options.   Deita contains multi-turn conversations. By default, only the first instruction-response pair from each of these multi-turn conversations is included. If you want to override this behavior and include the follow-up instructions and responses, set `--data.include_multiturn_conversations True`, which will include all multi-turn conversations as regular prompt-response pairs. Including the multi-turn answers, the dataset consists of 209,272 prompt-response pairs. The Deita dataset distribution without multi-turn conversations is shown below. The Deita dataset distribution including multi-turn conversations is depicted in the following histogram. You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information).
For this dataset, a cut-off of 512 may be a good choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Deita \ --train.max_seq_length 512 ```   ### Dolly The Dolly dataset is a publicly available collection of 15k instruction-following entries created by Databricks. It spans multiple behavioral domains, as described in the [InstructGPT](https://arxiv.org/abs/2203.02155) paper. These include areas like brainstorming, classification, closed QA, content creation, information retrieval, open QA, and summary generation. The usage is similar to the Alpaca dataset described above. Using Falcon 7b as an example, we can use the dataset as follows: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Dolly ```   > [!TIP] > Use `litgpt finetune --data.help Dolly` to list additional dataset-specific command line options.   The Dolly dataset distribution is shown below. You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information). For this dataset, a cut-off of 512 may be a good choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data Dolly \ --train.max_seq_length 512 ```   ### LongForm LongForm is a semi-synthetic dataset based on raw text corpora for which the instructions were generated via an LLM. For more details about the instruction-generation process, please refer to the [LongForm research paper](https://arxiv.org/abs/2304.08460) by Köksal et al. According to the research paper, a Llama 7B model trained on LongForm achieves substantially better performance than the same Llama model trained on the 2x larger Alpaca dataset. LongForm consists of 23,652 training samples, 2,042 validation samples, and 2,045 test samples. (In LitGPT, the validation samples are currently not used.)
The detailed dataset composition, based on a table taken from the [dataset repository](https://github.com/akoksal/LongForm), is as follows: | **Type** | **Source** | **Number of Examples** | |------------------------|----------------|------------------------| | **Corpora** | C4 | 10,000 | | | Wikipedia | 5,000 | | **Structured Corpora** | Stack Exchange | 4,380 | | | WikiHow | 2,500 | | **Tasks** | NIv2 | 3,684 | | | Big Bench | 600 | | | BEA-GEC | 1,203 | | | Enron | 372 | | **Total** | | 27,739 | | | | | | **Train** | | 23,652 | | **Validation** | | 2,042 | | **Test** | | 2,045 | License information is not provided but would depend on the individual subsets listed above. The LongForm dataset distribution is shown below. You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information). For this dataset, a cut-off of 1500 may be a good choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data LongForm \ --train.max_seq_length 1500 ```   > [!TIP] > Use `litgpt finetune --data.help LongForm` to list additional dataset-specific command line options.     ### LIMA The LIMA dataset is a collection of 1,000 carefully curated prompts and responses, as described in the [LIMA: Less Is More for Alignment](https://arxiv.org/abs/2305.11206) paper. The dataset is sourced from three community Q&A websites: Stack Exchange, wikiHow, and the Pushshift Reddit Dataset. In addition, it contains prompts and answers written and collected by the authors of the LIMA paper. The usage is similar to the Dolly dataset described above, except that it requires a Hugging Face access token that you need to copy & paste from your Hugging Face account.
Using Falcon 7b as an example, we can use the dataset as follows: ```bash export HF_TOKEN="insert_your_huggingface_token_here" litgpt finetune_lora tiiuae/falcon-7b \ --data LIMA ```   > [!TIP] > Use `litgpt finetune --data.help LIMA` to list additional dataset-specific command line options.   LIMA contains a handful of multi-turn conversations. By default, only the first instruction-response pair from each of these multi-turn conversations is included. If you want to override this behavior and include the follow-up instructions and responses, set `--data.include_multiturn_conversations True`. The LIMA dataset distribution is shown below. You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information). For this dataset, a cut-off of 512 may be a good choice: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data LIMA \ --train.max_seq_length 512 ```   ### FLAN FLAN is a collection of several dataset subsets by Google. In particular, the provided script in LitGPT loads the subsets from [here](https://huggingface.co/datasets/Muennighoff/flan). By default, all training subsets (1,386,050 samples) and validation sets (367,190 samples) are combined into a single dataset: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data FLAN ``` However, you can also select individual subsets via comma-separated strings as follows: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data FLAN \ --data.subsets "aeslc_10templates,ag_news_subset_10templates,anli_r1_10templates" ```   > [!TIP] > Use `litgpt finetune --data.help FLAN` to list additional dataset-specific command line options.   You can find a list of all 66 supported subsets [here](https://huggingface.co/datasets/Muennighoff/flan).  
## Preparing Custom Datasets for Instruction Finetuning The models in LitGPT expect datasets for instruction finetuning in the following format: ```text [ { "instruction": "Write a limerick about a pelican.", "input": "", "output": "There once was a pelican so fine, \nHis beak was as colorful as sunshine,\nHe would fish all day,\nIn a very unique way,\nThis pelican was truly divine!\n\n\n" }, { "instruction": "Identify the odd one out from the group.", "input": "Carrot, Apple, Banana, Grape", "output": "Carrot\n\n" } ] ``` (Note that depending on the task, the `"input"` text can be an empty string, as shown above.) You can use your own data in LitGPT by either reading in a JSON file in the format shown above or by implementing a custom `DataModule`.   ### Preparing Custom Datasets From a JSON File You can prepare a custom dataset using a JSON file that contains a list of dictionaries, each with these keys: - `instruction`: A description of the task. - `input`: An additional input for the instruction. Some samples need one; for the others, use an empty string. - `output`: The expected response. > If any of the fields are missing, then the script will fail to read the dataset. Then simply run any of the finetuning scripts with this input: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data JSON \ --data.json_path path/to/your/data.json \ --data.val_split_fraction 0.1 ``` You can also customize how the dataset is read by using these additional parameters: - `val_split_fraction`: The fraction of the data to hold out as the validation set. Defaults to `0.1` - `seed`: The seed value to reproduce the same random splits for train and validation data. - `mask_inputs`: Whether to mask the prompt section from the label (with `ignore_index`). - `ignore_index`: The index to use for labels that should be ignored. Defaults to `-100` (used when `mask_inputs` is `True`).
To use the settings described above, you can add the respective command line arguments when calling the finetuning scripts as shown in the example below: ```bash litgpt finetune_lora tiiuae/falcon-7b \ --data JSON \ --data.json_path path/to/your/data.json \ --data.val_split_fraction 0.1 \ --data.seed 42 \ --data.mask_inputs False \ --data.ignore_index -100 ``` You can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split.   > [!TIP] > Use `litgpt finetune --data.help JSON` to list additional dataset-specific command line options.   ### Preparing Custom Datasets Using DataModule If you don't have a JSON file following the format described in the previous section, the easiest way to prepare a new dataset is to copy and modify one of the existing data modules in LitGPT: - [`litgpt/data/alpaca.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/alpaca.py) (if you plan to load a dataset from a JSON file); - [`litgpt/data/lima.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/lima.py) (if you plan to load a dataset using the `datasets` Python library). Note that you only need to modify a small fraction of the code file, namely the portion that downloads and formats the training data (see the `prepare_data()` and `setup()` methods).   ## Preparing Pretraining Datasets In addition to the finetuning datasets described above, LitGPT also supports several datasets for pretraining. The pretraining datasets are described in more detail in the following separate tutorial documents: - [Pretrain TinyLlama on Slimpajama and Starcoder](./pretrain_tinyllama.md) ================================================ FILE: tutorials/pretrain.md ================================================ # Pretrain LLMs with LitGPT This document explains how to pretrain LLMs using LitGPT.  
## Using the `litgpt pretrain` command You can pretrain models in LitGPT using the `litgpt pretrain` API starting with any of the available architectures listed by calling `litgpt pretrain list` without any additional arguments:   > [!TIP] > To install all required dependencies before pretraining, first run `pip install "litgpt[all]"`.   ```bash litgpt pretrain list ``` Shown below is an abbreviated list: ``` ValueError: Please specify --model_name . Available values: Camel-Platypus2-13B ... Gemma-2b ... Llama-2-7b-hf ... Mixtral-8x7B-v0.1 ... pythia-14m ``` For demonstration purposes, we can pretrain a small 14-million-parameter Pythia model on the small TinyStories dataset using the [debug.yaml config file](https://github.com/Lightning-AI/litgpt/blob/main/config_hub/pretrain/debug.yaml) as follows: ```bash litgpt pretrain pythia-14m \ --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/debug.yaml ```   ## Pretrain on custom data The simplest way to get started with pretraining on a small custom dataset is by using the `TextFiles` data module, which lets you pretrain a model on a folder containing plain text files.   > [!NOTE] > This approach adds a beginning-of-sequence token at the beginning of each text file. However, it otherwise assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable in case a text file consists of multiple documents.   > [!WARNING] > Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data as plain text files, as this `TextFiles` approach requires, is only feasible for datasets of limited size.
> For datasets that take up multiple gigabytes, we recommend preprocessing them with [LitData](https://github.com/Lightning-AI/litdata) and then reading them from a local directory or S3 connection using `--data LitData`.   For instance, assume you stored a number of text files in a `custom_pretraining_data` folder (we recommend avoiding small files and concatenating them to files of at least 50 MB for efficiency): ```bash ~ ls -lh custom_pretraining_data total 3225M -rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_1.txt -rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_2.txt -rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_3.txt -rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_4.txt -rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_5.txt ... ``` You can then use the `TextFiles` API to pretrain a model (here a small `pythia-14m` model for illustration purposes) from scratch as follows: ```bash litgpt download EleutherAI/pythia-14m \ --tokenizer_only true litgpt pretrain pythia-14m \ --tokenizer_dir EleutherAI/pythia-14m \ --data TextFiles \ --data.train_data_path custom_pretraining_data \ --train.lr_warmup_steps=200 \ --optimizer AdamW \ --optimizer.lr 0.005 ```   > [!TIP] > Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.     ## Continued pretraining on custom data Often, it makes sense to adopt an existing pretrained model and further pretrain it on our own custom data. The existing pretrained model can be either our own pretrained model or a model downloaded from a model hub. The following subsections illustrate three typical scenarios: 1. Starting from a downloaded base model 2. Continuing the pretraining after interruption 3. Further pretraining on a different dataset   > [!NOTE] > This approach assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable.  
> [!WARNING] > Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data as plain text files, as this `TextFiles` approach requires, is only feasible for datasets of limited size. > For datasets that take up multiple gigabytes, we recommend preprocessing them with [LitData](https://github.com/Lightning-AI/litdata) and then reading them from a local directory or S3 connection using `--data LitData --data.path path/to/your/data`.   ### 1) Continued pretraining when starting from a downloaded base model For instance, let's assume we download a Pythia model: ```bash litgpt download EleutherAI/pythia-160m ``` Next, assume we have a custom dataset stored in text files similar to the one in the *Pretrain on custom data* section above. We can further pretrain the Pythia model via the `--initial_checkpoint_dir` setting as follows: ```bash litgpt pretrain pythia-160m \ --initial_checkpoint_dir EleutherAI/pythia-160m \ --tokenizer_dir EleutherAI/pythia-160m \ --out_dir ./new_pretrained_checkpoint \ --data TextFiles \ --data.train_data_path custom_pretraining_data \ --train.max_tokens 1_000_000 ```   > [!TIP] > Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.   ### 2) Continued pretraining after interruption In case you interrupted a training run, you can continue it with the `--resume` option, for example: ```bash litgpt pretrain pythia-160m \ --resume "auto" \ --tokenizer_dir EleutherAI/pythia-160m \ --out_dir ./new_pretrained_checkpoint \ --data TextFiles \ --data.train_data_path custom_pretraining_data \ --train.max_tokens 1_000_000 ```   ### 3) Continued pretraining on a new dataset Suppose you pretrained a model using the examples above.
To further pretrain the model on a new dataset, you first need to convert the pretrained checkpoint via the following command: ```bash litgpt convert_pretrained_checkpoint ./new_pretrained_checkpoint/final ./new_pretrained_checkpoint_converted ``` Then, you can pretrain the converted model on the new dataset as follows: ```bash litgpt pretrain pythia-160m \ --initial_checkpoint_dir ./new_pretrained_checkpoint_converted \ --tokenizer_dir EleutherAI/pythia-160m \ --out_dir ./new_pretrained_checkpoint_2 \ --data TextFiles \ --data.train_data_path custom_pretraining_data_2 \ --train.max_tokens 1_000_000 ```   ## Pretrain a 1.1B TinyLlama model You can find an end-to-end LitGPT tutorial for pretraining a TinyLlama model using LitGPT [here](pretrain_tinyllama.md).   ## Optimize LitGPT pretraining with Lightning Thunder [Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) is a source-to-source compiler for PyTorch, which is fully compatible with LitGPT. In experiments, Thunder resulted in a 40% speed-up compared to using regular PyTorch when finetuning a 7B Llama 2 model. For more information, see the [Lightning Thunder extension README](../extensions/thunder/README.md).   
## Project templates The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:   | | | |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| |

[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[

](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) | | [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | | | | ================================================ FILE: tutorials/pretrain_tinyllama.md ================================================ # Pretrain TinyLlama This tutorial will walk you through pretraining [TinyLlama](https://github.com/jzhang38/TinyLlama/). > [!TIP] > To get started with zero setup, clone the [TinyLlama studio on Lightning AI](https://lightning.ai/lightning-ai/studios/llm-pretrain-tinyllama-1-1b).   ## What's TinyLlama? [TinyLlama](https://github.com/jzhang38/TinyLlama/) is architecturally the same as Meta AI's Llama 2, but only has 1.1B parameters and is trained for multiple epochs on a mix of [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) and [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) datasets. Here is a quick fact sheet: | Name | Description | |-------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| | Parameters | 1.1B | | Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size: 5632 | | Sequence Length | 2048 | | Learning Rate | 4e-4 | | Learning Rate Schedule | Cosine with 2000 warmup steps | | Training Data | [SlimPajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) (893 GB), [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) (290 GB) | | Combined Dataset Size | Around 950B tokens | | Total Tokens During Training | 3 trillion (3 epochs) | | Time to complete training | ~ 4 weeks with 64 A100 GPUs | | Model FLOPs Utilization (MFU) | 52% | (this table was sourced from the authors' [README](https://github.com/jzhang38/TinyLlama/))   ## Download datasets You can download the data using git lfs: ```bash # Make sure you have git-lfs installed (https://git-lfs.com): sudo apt install git-lfs ``` ```bash 
git clone https://huggingface.co/datasets/cerebras/slimpajama-627b data/slimpajama-raw git clone https://huggingface.co/datasets/bigcode/starcoderdata data/starcoderdata-raw ``` Around 1.2 TB of disk space is required to store both datasets.   ## Prepare the datasets for training Before you can pretrain with LitGPT on these datasets, you need to read, tokenize, and write the data in binary chunks. This will leverage the `litdata` optimization pipeline and streaming dataset. First, install additional dependencies for preprocessing: ```bash pip install '.[all]' ``` You will need to have the tokenizer config available: ```bash litgpt download meta-llama/Llama-2-7b-hf \ --access_token your_hf_token \ --tokenizer_only true ``` Then, run the preprocessing script for each dataset and split. You will require **1.1 TB** of disk space for Starcoder and **2.5 TB** of space for the SlimPajama dataset. **Starcoder:** ```bash python litgpt/data/prepare_starcoder.py \ --input_dir data/starcoderdata-raw \ --output_dir data/starcoder \ --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf ``` **SlimPajama:** ```bash python litgpt/data/prepare_slimpajama.py \ --input_dir data/slimpajama-raw/validation \ --output_dir data/slimpajama/val \ --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf python litgpt/data/prepare_slimpajama.py \ --input_dir data/slimpajama-raw/test \ --output_dir data/slimpajama/test \ --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf python litgpt/data/prepare_slimpajama.py \ --input_dir data/slimpajama-raw/train \ --output_dir data/slimpajama/train \ --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf ``` If you want to run on a small slice of the datasets first, pass the flag `--fast_dev_run=true` to the commands above. In the above we are assuming that you will be using the same tokenizer as used in Llama/TinyLlama, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here.   
## Pretraining Running the pretraining script with its default settings requires at least 8 A100 GPUs. ```bash litgpt pretrain --config config_hub/pretrain/tinyllama.yaml ```   > [!TIP] > Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset options.   The script will save checkpoints periodically to the folder `out/`. By default, the `pretrain` script will pretrain the model with FSDP in `bfloat16` mixed precision and gradient accumulation. Note that `pretrain` is not actually a model-specific training script, so feel free to [try other configurations](../config_hub) or change the model type and size by passing a different string to the model name argument, for example: ```shell litgpt pretrain Gemma-2b ``` The currently supported model names can be listed by executing `litgpt pretrain` without any additional arguments. Keep in mind that training with a single machine will take weeks. To speed up the process, you'll need access to a cluster. Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/fundamentals/launch.html#launch-on-a-cluster) to launch the script across machines: - [Lightning AI](https://lightning.ai/docs/fabric/stable/guide/multi_node/cloud.html) - [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) - [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html) - [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html) The script exposes several hyperparameters you can tweak through the command line. For instance, `--train.micro_batch_size` should be adjusted so the process will use the available GPU memory. For more tips to avoid out-of-memory issues, please also see the more detailed [Dealing with out-of-memory (OOM) errors](oom.md) guide. Last, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker. 
LitGPT supports multiple experiment trackers including: - **TensorBoard** (default): Local visualization with TensorBoard - **CSV Logger**: Simple local logging to CSV files - **WandB**: Cloud-based experiment tracking with Weights & Biases - **MLflow**: MLflow experiment tracking - **[LitLogger](https://github.com/Lightning-AI/LitLogger)**: Lightning.ai's native experiment tracking (set `--logger_name=litlogger`) As an example, we included WandB (set `--logger_name=wandb`) to show how you can integrate any experiment tracking framework. For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).   ## Resume training The checkpoints saved during pretraining contain all the information to resume if needed. Simply rerun the script with the `--resume` argument added: ```bash litgpt pretrain tiny-llama \ --config config_hub/pretrain/tinyllama.yaml \ --resume out/pretrain/tiny-llama/step-00060500 ``` **Important:** Each checkpoint is a directory. Point to the directory, not the 'lit_model.pth' file inside of it.   > [!TIP] > Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset options.     ## Export checkpoints After training is completed, you can convert the checkpoint to a format that can be loaded for evaluation, inference, finetuning etc. ```bash litgpt convert_pretrained_checkpoint out/pretrain/tiny-llama/step-00060500 \ --output_dir checkpoints/tiny-llama/final ``` After conversion, the output folder will contain these files: ``` checkpoints/tiny-llama/final ├── model_config.yaml ├── lit_model.pth ├── tokenizer_config.json ├── tokenizer.json └── tokenizer.model ``` You can then use this checkpoint folder to run [evaluation](evaluation.md), [inference](inference.md), [finetuning](finetune_lora.md) or [process the checkpoint further](convert_lit_models.md).   
## Project templates The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:   | | | |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| |

[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[

](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) | | [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | | | | ================================================ FILE: tutorials/python-api.md ================================================ # LitGPT Python API This is a work-in-progress draft describing the current LitGPT Python API (experimental and subject to change). ## Model loading Use the `LLM.load` method to load a model from a LitGPT model checkpoint folder. For example, consider loading a Phi-2 model. If a given checkpoint directory `"microsoft/phi-2"` does not exist as a local checkpoint directory, the model will be downloaded automatically from the HF Hub (assuming that `"microsoft/phi-2"` is a valid repository name): ```python from litgpt import LLM llm_1 = LLM.load("microsoft/phi-2") ``` ``` config.json: 100%|████████████████████████████████████████████████| 735/735 [00:00<00:00, 7.75MB/s] generation_config.json: 100%|█████████████████████████████████████| 124/124 [00:00<00:00, 2.06MB/s] model-00001-of-00002.safetensors: 100%|███████████████████████████| 5.00G/5.00G [00:12<00:00, 397MB/s] model-00002-of-00002.safetensors: 100%|███████████████████████████| 564M/564M [00:01<00:00, 421MB/s] model.safetensors.index.json: 100%|███████████████████████████████| 35.7k/35.7k [00:00<00:00, 115MB/s] tokenizer.json: 100%|█████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 21.5MB/s] tokenizer_config.json: 100%|██████████████████████████████████████| 7.34k/7.34k [00:00<00:00, 80.6MB/s] ```   > [!NOTE] > To get a list of all supported models, execute `litgpt download list` in the command line terminal.  
If you attempt to load the model again, LitGPT will load this model from a local directory since it's already been downloaded: ```python llm_2 = LLM.load("microsoft/phi-2") ``` If you created a pretrained or finetuned model checkpoint via LitGPT, you can load it in a similar fashion: ```python my_llm = LLM.load("path/to/my/local/checkpoint") ```   ## Generate/Chat Generate output using the `.generate` method: ```python from litgpt import LLM llm = LLM.load("microsoft/phi-2") text = llm.generate("What do Llamas eat?", top_k=1, max_new_tokens=30) print(text) ``` ``` Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract ``` Alternatively, stream the response one token at a time: ```python result = llm.generate("What do Llamas eat?", stream=True) for e in result: print(e, end="", flush=True) ``` ``` Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract ```   ## Saving models After finetuning or modifying a model, you can save it to disk using the `.save()` method: ```python from litgpt import LLM llm = LLM.load("microsoft/phi-2") # ... perform finetuning or modifications ... llm.save("path/to/save/directory") ``` The saved checkpoint can then be loaded later: ```python llm = LLM.load("path/to/save/directory") ```   ## Random weights To start with random weights, for example, if you plan to pretrain a model from scratch, initialize the model with `init="random"`. Note that this requires passing a `tokenizer_dir` that contains a valid tokenizer file. ```python from litgpt.api import LLM llm = LLM.load("pythia-160m", init="random", tokenizer_dir="EleutherAI/pythia-160m") ```   ## Multi-GPU strategies By default, the model is loaded onto a single GPU. Optionally, you can use the `.distribute()` method with the "sequential" or "tensor_parallel" `generate_strategy` settings. 
### Sequential strategy The `generate_strategy="sequential"` setting loads different parts of the model onto different GPUs. The goal behind this strategy is to support models that cannot fit into single-GPU memory. (Note that if you have a model that can fit onto a single GPU, this sequential strategy will be slower.) ```python from litgpt.api import LLM llm = LLM.load( "microsoft/phi-2", distribute=None ) llm.distribute( generate_strategy="sequential", devices=4, # Optional setting, otherwise uses all available GPUs fixed_kv_cache_size=256 # Optionally use a small kv-cache to further reduce memory usage ) ``` ``` Using 4 devices Moving '_forward_module.transformer.h.31' to cuda:3: 100%|██████████| 32/32 [00:00<00:00, 32.71it/s] ``` After initializing the model, the model can be used via the `generate` method similar to the default `generate_strategy` setting: ```python text = llm.generate("What do llamas eat?", max_new_tokens=100) print(text) ``` ``` Llamas are herbivores and their diet consists mainly of grasses, plants, and leaves. ```   ### Tensor parallel strategy The sequential strategy explained in the previous subsection distributes the model sequentially across GPUs, which allows users to load models that would not fit onto a single GPU. However, due to this method's sequential nature, processing is naturally slower than parallel processing. To take advantage of parallel processing via tensor parallelism, you can use the `generate_strategy="tensor_parallel"` setting. However, this method has downsides: the initial setup may be slower for large models, and it cannot run in interactive processes such as Jupyter notebooks. 
```python from litgpt.api import LLM if __name__ == "__main__": llm = LLM.load( model="meta-llama/Meta-Llama-3.1-8B-Instruct", distribute=None ) llm.distribute(generate_strategy="tensor_parallel", devices=4) print(llm.generate(prompt="What do llamas eat?")) print(llm.generate(prompt="What is 1+2?", top_k=1)) ```   ## Speed and resource estimates Use the `.benchmark()` method to compare the computational performance of different settings. The `.benchmark()` method takes the same arguments as the `.generate()` method. For example, we can estimate the speed and GPU memory consumption as follows (the resulting numbers were obtained on an A10G GPU): ```python from litgpt.api import LLM from pprint import pprint llm = LLM.load( model="microsoft/phi-2", distribute=None ) llm.distribute(fixed_kv_cache_size=500) text, bench_d = llm.benchmark(prompt="What do llamas eat?", top_k=1, stream=True) print(text) pprint(bench_d) # Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized # digestive system that allows them to efficiently extract nutrients from plant material. # Using 1 device(s) # Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material. # {'Inference speed in tokens/sec': [17.617540650112936], # 'Seconds to first token': [0.6533610639999097], # 'Seconds total': [1.4758019020000575], # 'Tokens generated': [26], # 'Total GPU memory allocated in GB': [5.923729408]} ``` To get more reliable estimates, it's recommended to repeat the benchmark for multiple iterations via `num_iterations=10`: ```python text, bench_d = llm.benchmark(num_iterations=10, prompt="What do llamas eat?", top_k=1, stream=True) print(text) pprint(bench_d) # Using 1 device(s) # Llamas are herbivores and primarily eat grass, leaves, and shrubs. 
They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material. # {'Inference speed in tokens/sec': [17.08638672485105, # 31.79908547222976, # 32.83646959864293, # 32.95994240022436, # 33.01563039816964, # 32.85263413816648, # 32.82712094713627, # 32.69216141907453, # 31.52431714347663, # 32.56752130561681], # 'Seconds to first token': [0.7278506560005553, # 0.022963577999689733, # 0.02399449199947412, # 0.022921959999621322, # ... ``` As one can see, the first iteration may take longer due to warmup times. So, it's recommended to discard the first iteration: ```python for key in bench_d: bench_d[key] = bench_d[key][1:] ``` For better visualization, you can use the `benchmark_dict_to_markdown_table` function: ```python from litgpt.api import benchmark_dict_to_markdown_table print(benchmark_dict_to_markdown_table(bench_d)) ``` | Metric | Mean | Std Dev | |-------------------------------------|-----------------------------|-----------------------------| | Seconds total | 0.80 | 0.01 | | Seconds to first token | 0.02 | 0.00 | | Tokens generated | 26.00 | 0.00 | | Inference speed in tokens/sec | 32.56 | 0.50 | | Total GPU memory allocated in GB | 5.92 | 0.00 |   # PyTorch Lightning Trainer support You can use the LitGPT `LLM` class with the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to pretrain and finetune models. The examples below show the usage via a simple 160 million parameter model for demonstration purposes to be able to quickly try it out. However, you can replace the `EleutherAI/pythia-160m` model with any model supported by LitGPT (you can find a list of supported models by executing `litgpt download list` or visiting the [model weight docs](download_model_weights.md)).   
## Step 1: Define a `LightningModule` First, we define a `LightningModule` similar to what we would do when working with other types of neural networks in PyTorch Lightning: ```python import torch import litgpt from litgpt import LLM from litgpt.data import Alpaca2k import lightning as L class LitLLM(L.LightningModule): def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None): super().__init__() self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None) self.trainer_ckpt_path = trainer_ckpt_path def setup(self, stage): self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path) def training_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch): logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"]) self.log("validation_loss", loss, prog_bar=True) return loss def configure_optimizers(self): warmup_steps = 10 optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)) return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}} ``` In the code example above, note how we set `distribute=None` in `LLM.load()` in the `__init__` method. This step is necessary because we want to let the PyTorch Lightning Trainer handle the GPU devices. We then call `self.llm.trainer_setup` in the `setup()` method, which adjusts the LitGPT settings to be compatible with the Trainer. Also note that the learning-rate scheduler is registered with `"interval": "step"` so that the linear warmup advances after every optimizer step, and the warmup factor is capped at 1.0. With Lightning's default `"epoch"` interval, the scheduler would only advance once per epoch, and the learning rate would stay at zero for the entire first epoch. Other than that, everything else looks like a standard `LightningModule`. Next, we have a selection of different use cases, but first, let's set some general settings to specify the batch size and gradient accumulation steps: ```python batch_size = 8 accumulate_grad_batches = 1 ``` For larger models, you may want to decrease the batch size and increase the number of accumulation steps. 
(Setting `accumulate_grad_batches = 1` effectively disables gradient accumulation, and it is only shown here for reference in case you wish to change this setting.) ## Step 2: Using the Trainer   ### Use case 1: Pretraining from random weights In case you plan to train a model from scratch (not recommended over finetuning because training a model from scratch in general requires substantial time and resources), you can do it as follows: ```python # Create model with random as opposed to pretrained weights llm = LLM.load("EleutherAI/pythia-160m", tokenizer_dir="EleutherAI/pythia-160m", init="random") llm.save("pythia-160m-random-weights") del llm lit_model = LitLLM(checkpoint_dir="pythia-160m-random-weights", tokenizer_dir="EleutherAI/pythia-160m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ```   ### Use case 2: Continued pretraining or finetuning a downloaded model The continued pretraining or finetuning from a downloaded model checkpoint is similar to the example above, except that we can skip the initial steps of instantiating a model with random weights. ```python lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m") data = Alpaca2k() data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ```   ### Use case 3: Resume training from Trainer checkpoint Suppose you trained a model and decide to follow up with a few additional training rounds. 
This can be achieved as follows by loading an existing Trainer checkpoint: ```python import os def find_latest_checkpoint(directory): latest_checkpoint = None latest_time = 0 for root, _, files in os.walk(directory): for file in files: if file.endswith('.ckpt'): file_path = os.path.join(root, file) file_time = os.path.getmtime(file_path) if file_time > latest_time: latest_time = file_time latest_checkpoint = file_path return latest_checkpoint lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs")) data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ```   ### Use case 4: Resume training after saving a checkpoint manually This example illustrates how we can save a LitGPT checkpoint from a previous training run that we can load and use later. Note that compared to using the Trainer checkpoint in the previous section, the model saved via this approach also contains the tokenizer and other relevant files. Hence, this approach does not require the original `"EleutherAI/pythia-160m"` model checkpoint directory. 
```python lit_model.llm.save("finetuned_checkpoint") del lit_model lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint") data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512) trainer = L.Trainer( devices=1, accelerator="cuda", max_epochs=1, accumulate_grad_batches=accumulate_grad_batches, precision="bf16-true", ) trainer.fit(lit_model, data) lit_model.llm.model.to(lit_model.llm.preprocessor.device) lit_model.llm.generate("hello world") ``` ================================================ FILE: tutorials/quantize.md ================================================ # Quantize the model This document provides different strategies for quantizing the various models available in LitGPT to reduce GPU memory usage, which is useful for running larger models on certain GPU hardware. **All the examples below were run on an A100 40GB GPU with CUDA 12.1.** > [!NOTE] > Quantization also supports finetuning via [QLoRA](finetune_lora.md) ## Baseline It's useful to start with a baseline to have a reference point for memory savings via the various quantization methods. ```bash litgpt generate tiiuae/falcon-7b \ --precision 32-true \ --max_new_tokens 256 ... Time for inference 1: 6.93 sec total, 36.96 tokens/sec. Memory used: 28.95 GB ``` First, using a lower precision compared to 32-bit float can result in two times reduced memory consumption. You can either try setting `--precision 16-true` for regular 16-bit precision or `--precision bf16-true` if your GPU supports brain-float 16-bit precision. ([This brief video](https://lightning.ai/courses/deep-learning-fundamentals/9.0-overview-techniques-for-speeding-up-model-training/unit-9.1-accelerated-model-training-via-mixed-precision-training/) explains the difference between regular 16-bit and bf16-bit precision.) In short, when `--precision bf16-true` or `--precision 16-true` is used, the model weights will automatically be converted and consume less memory. 
However, this might not be enough for large models or when using GPUs with limited memory. ```bash litgpt generate tiiuae/falcon-7b \ --precision bf16-true \ --max_new_tokens 256 ... Time for inference 1: 5.37 sec total, 47.66 tokens/sec. Memory used: 14.50 GB ``` To reduce the memory requirements further, LitGPT supports several quantization techniques, which are shown below. > [!TIP] > Most quantization examples below also use the `--precision bf16-true` setting explained above. If your GPU does not support `bfloat16`, you can change it to `--precision 16-true`. ## `bnb.nf4` Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works. > [!IMPORTANT] > `bitsandbytes` only supports `CUDA` devices and the `Linux` operating system. > Windows users should use [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl). Uses the normalized float 4 (nf4) data type. This is recommended over "fp4" based on the paper's experimental results and theoretical analysis. ```bash pip install bitsandbytes litgpt generate tiiuae/falcon-7b \ --quantize bnb.nf4 \ --precision bf16-true \ --max_new_tokens 256 ... Time for inference 1: 6.80 sec total, 37.62 tokens/sec Memory used: 5.72 GB ``` ## `bnb.nf4-dq` Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works. "dq" stands for "Double Quantization" which reduces the average memory footprint by quantizing the quantization constants. On average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model). ```bash pip install bitsandbytes litgpt generate tiiuae/falcon-7b \ --quantize bnb.nf4-dq \ --precision bf16-true \ --max_new_tokens 256 ... 
Time for inference 1: 8.09 sec total, 30.87 tokens/sec Memory used: 5.38 GB ``` ## `bnb.fp4` Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works. Uses pure FP4 quantization. ```bash pip install bitsandbytes litgpt generate tiiuae/falcon-7b \ --quantize bnb.fp4 \ --precision bf16-true \ --max_new_tokens 256 ... Time for inference 1: 6.92 sec total, 36.98 tokens/sec Memory used: 5.72 GB ``` ## `bnb.fp4-dq` Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works. "dq" stands for "Double Quantization" which reduces the average memory footprint by quantizing the quantization constants. On average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model). ```bash pip install bitsandbytes litgpt generate tiiuae/falcon-7b \ --quantize bnb.fp4-dq \ --precision bf16-true \ --max_new_tokens 256 ... Time for inference 1: 10.02 sec total, 25.54 tokens/sec Memory used: 5.38 GB ``` ## `bnb.int8` Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2208.07339) to learn more about how it works. ```bash pip install bitsandbytes litgpt generate tiiuae/falcon-7b \ --quantize bnb.int8 \ --precision 16-true \ --max_new_tokens 256 ... 
Time for inference 1: 20.22 sec total, 12.66 tokens/sec Memory used: 8.70 GB ``` ================================================ FILE: tutorials/resource-tables.md ================================================ # Resource Tables - Last updated: 10/20/2023 - LitGPT version: commit 8641822 - Hardware: NVIDIA A100-SXM4-40GB - OS: Ubuntu 22.04.3 LTS (x86_64) - Nvidia driver version: 525.125.06 - Relevant libraries - PyTorch 2.1.0+cu121 - Bitsandbytes 0.41.1 This document provides an overview and examples of hardware requirements when running models in LitGPT. For additional tips on lowering the GPU memory footprint, please also see the [Dealing with out-of-memory (OOM) errors](oom.md) document. All experiments were run using 16-bit brain floating point precision (`--precision bf16-true`). If your GPU does not support brain floating point precision, you can use regular 16-bit floating point precision (`--precision 16-true`). All experiments were conducted using the Alpaca dataset with its default length. Note that due to different tokenizers being used by the different models, the number of tokens in the longest training example differs based on the model: - phi1.5: 1044 tokens - StableLM Alpha: 1034 tokens - Llama 2: 1304 tokens - Falcon: 1079 tokens Note that the number of tokens in the training set does not affect the supported context width (block size) of the models, which is as follows: - phi1.5: 2048 tokens - StableLM 3B Alpha: 4096 tokens - Llama 2: 4096 tokens - Falcon: 2048 tokens - CodeLlama 13B: 16384 tokens   ## Finetuning with LoRA on 1 GPU The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_lora` command. 
| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | |-------|----------------|--------------|-----------------|----------------------|-------------|--------------------| | 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 4.82 GB | 1.62 min | | 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 3.78 GB | 1.77 min | | 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 3.72 GB | 1.87 min | | 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 6.76 GB | 1.65 min | | 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 10.68 GB | 1.70 min | | | | | | | | | | 3 B | StableLM Alpha | None | 1 | 2,097,152 | 9.69 GB | 1.24 min | | 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,097,152 | 6.35 GB | 1.82 min | | 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,097,152 | 6.19 GB | 1.87 min | | 3 B | StableLM Alpha | None | 2 | 2,097,152 | 12.10 GB | 1.33 min | | 3 B | StableLM Alpha | None | 4 | 2,097,152 | 16.92 GB | 1.50 min | | | | | | | | | | 7 B | Llama 2 | None | 1 | 4,194,304 | 21.30 GB | 2.36 min | | 7 B | Llama 2 | bnb.nf4 | 1 | 4,194,304 | 14.14 GB | 3.68 min | | 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,194,304 | 13.84 GB | 3.83 min | | 7 B | Llama 2 | None | 2 | 4,194,304 | 29.07 GB | 2.52 min | | 7 B | Llama 2 | None | 4 | 4,194,304 | OOM | - | | | | | | | | | | 13 B | Llama 2 | None | 1 | 6,553,600 | 38.12 GB | 3.19 min | | 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 23.14 GB | 6.38 min | | 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 22.55 GB | 6.55 min | | 13 B | Llama 2 | None | 2 | 6,553,600 | OOM | - | | 13 B | Llama 2 | None | 4 | 6,553,600 | OOM | - | | | | | | | | | | 40 B | Falcon | None | 1 | 12,042,240 | OOM | - | | 40 B | Falcon | bnb.nf4 | 1 | 12,042,240 | OOM | - | | 40 B | Falcon | bnb.nf4-dq | 1 | 12,042,240 | OOM | - |   ## Finetuning with Adapter on 1 GPU The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_adapter` command. 
| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | |------|----------------|--------------|-----------------|----------------------|-------------|--------------------| | 3 B | StableLM Alpha | None | 1 | 573,888 | 9.10 GB | 0.74 min | | 3 B | StableLM Alpha | bnb.nf4 | 1 | 573,888 | 5.65 GB | 1.38 min | | 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 573,888 | 5.48 GB | 1.46 min | | | | | | | | | | 7 B | Llama 2 | None | 1 | 1,229,760 | 19.98 GB | 1.50 min | | 7 B | Llama 2 | bnb.nf4 | 1 | 1,229,760 | 12.68 GB | 2.93 min | | 7 B | Llama 2 | bnb.nf4-dq | 1 | 1,229,760 | 12.38 GB | 3.00 min | The same config, but using the `litgpt finetune_adapter_v2` command. | Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations | |------|----------------|--------------|-----------------|----------------------|-------------|--------------------| | 3 B | StableLM Alpha | None | 1 | 2,125,248 | 10.71 GB | 0.87 min | | 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,125,248 | 7.41 GB | 1.59 min | | 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,125,248 | 7.25 GB | 1.62 min | | | | | | | | | | 7 B | Llama 2 | None | 1 | 4,279,744 | 25.51 GB | 1.81 min | | 7 B | Llama 2 | bnb.nf4 | 1 | 4,279,744 | 18.30 GB | 3.23 min | | 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,279,744 | 17.98 GB | 3.32 min |   ## Finetuning with LoRA on Multiple GPUs The following experiments were conducted on multiple A100 GPUs with a minibatch size of 128 using the `litgpt finetune_lora` command. 
| Size | Model | Quantization | Microbatch size | Trainable parameters | GPU | Max GPU RAM | Time 1k iterations | |-------|----------------|--------------|-----------------|----------------------|----------|-------------|--------------------| | 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 2 x A100 | 4.86 GB | 3.81 min | | 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 2 x A100 | N/A | - | | 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 2 x A100 | N/A | - | | 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 2 x A100 | 5.05 GB | 3.63 min | | 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 2 x A100 | 5.88 GB | 3.64 min | | | | | | | | | | | 3 B | StableLM Alpha | None | 1 | 2,097,152 | 2 x A100 | 12.75 GB | 2.92 min | | 3 B | StableLM Alpha | None | 2 | 2,097,152 | 2 x A100 | 12.94 GB | 3.06 min | | 3 B | StableLM Alpha | None | 4 | 2,097,152 | 2 x A100 | 13.45 GB | 3.86 min | | | | | | | | | - | | 7 B | Llama 2 | None | 1 | 4,194,304 | 2 x A100 | 22.18 GB | 5.93 min | | 7 B | Llama 2 | None | 2 | 4,194,304 | 2 x A100 | 22.47 GB | 6.48 min | | 7 B | Llama 2 | None | 4 | 4,194,304 | 2 x A100 | 23.39 GB | 8.66 min | | | | | | | | | | | 13 B | Llama 2 | None | 1 | 6,553,600 | 2 x A100 | OOM | - | | 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 2 x A100 | N/A | - | | 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 2 x A100 | N/A | - | | | | | | | | | | | 13 B | Llama 2 | None | 1 | 6,553,600 | 4 x A100 | 35.57 GB | 10.25 min | | 40 B | Falcon | None | 1 | 12,042,240 | 4 x A100 | OOM | - |   ## Single-GPU Inference | Size | Model | Quantization | GPU | Max GPU RAM | Token/sec | |-------|----------------|--------------|----------|-------------------------------------------|-----------| | 1.3 B | phi-1.5 | None | 1 x A100 | 2.86 GB | 42.56 | | 1.3 B | phi-1.5 | bnb.nf4 | 1 x A100 | 1.39 GB | 22.89 | | 1.3 B | phi-1.5 | bnb.nf4-dq | 1 x A100 | 1.33 GB | 22.75 | | | | | | | | | 3 B | StableLM Alpha | None | 1 x A100 | 7.30 GB | 49.01 | | 3 B | StableLM Alpha | bnb.nf4 | 1 x A100 | 3.20 GB | 29.04 | | 3 
B | StableLM Alpha | bnb.nf4-dq | 1 x A100 | 3.04 GB | 27.15 | | | | | | | | | 7 B | Llama 2 | None | 1 x A100 | 13.52 GB | 30.97 | | 7 B | Llama 2 | bnb.nf4 | 1 x A100 | 4.57 GB | 19.98 | | 7 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 4.26 GB | 17.3 | | | | | | | | | 13 B | Llama 2 | None | 1 x A100 | 26.21 GB | 24.82 | | 13 B | Llama 2 | bnb.nf4 | 1 x A100 | 8.32 GB | 16.73 | | 13 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 7.72 GB | 14.43 | | | | | | | | | 34 B | CodeLlama | None | 1 x A100 | OOM | - | | 34 B | CodeLlama | bnb.nf4 | 1 x A100 | 20.52 GB | 14.32 | | 34 B | CodeLlama | bnb.nf4-dq | 1 x A100 | 18.95 GB | 12.37 | | | | | | | | | 40 B | Falcon | None | 1 x A100 | OOM | - | | 40 B | Falcon | bnb.nf4 | 1 x A100 | 26.55 GB | 13.25 | | 40 B | Falcon | bnb.nf4-dq | 1 x A100 | 24.63 GB | 11.64 | | | | | | | | | 70 B | Llama 2 | None | 1 x A100 | OOM | - | | 70 B | Llama 2 | bnb.nf4 | 1 x A100 | CUDA error: CUBLAS_STATUS_NOT_INITIALIZED | - | | 70 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 37.21 GB | 7.97 |