Repository: Lightning-AI/litgpt
Branch: main
Commit: 162ad9bee317
Files: 233
Total size: 1.8 MB
Directory structure:
gitextract_ctr2cg_x/
├── .devcontainer/
│ ├── Dockerfile
│ └── devcontainer.json
├── .github/
│ ├── CODEOWNERS
│ ├── ISSUE_TEMPLATE/
│ │ ├── ask-a-question.md
│ │ ├── bug-report.yaml
│ │ └── feature-request.md
│ ├── dependabot.yml
│ └── workflows/
│ ├── check-links.yml
│ ├── cpu-tests.yml
│ ├── mkdocs-deploy.yml
│ └── publish-pkg.yml
├── .gitignore
├── .lightning/
│ └── workflows/
│ └── tests.yaml
├── .pre-commit-config.yaml
├── CITATION.cff
├── LICENSE
├── README.md
├── config_hub/
│ ├── finetune/
│ │ ├── README.md
│ │ ├── falcon-7b/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── gemma-2b/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── gemma-7b/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── gemma2-2b/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── gemma2-9b/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── llama-2-7b/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── llama-3-8b/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── llama-3.1-8b/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── llama-3.2-1B/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── llama-3.2-3B/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── mistral-7b/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── mistral-7b-v0.2/
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── phi-2/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── phi-3/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ ├── stablelm-base-alpha-3b/
│ │ │ ├── full.yaml
│ │ │ ├── lora.yaml
│ │ │ └── qlora.yaml
│ │ └── tiny-llama/
│ │ ├── full.yaml
│ │ ├── lora.yaml
│ │ └── qlora.yaml
│ └── pretrain/
│ ├── debug.yaml
│ ├── microllama.yaml
│ ├── tinyllama.yaml
│ └── tinystories.yaml
├── extensions/
│ ├── thunder/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── pretrain.py
│ │ ├── strategies/
│ │ │ ├── __init__.py
│ │ │ ├── thunder_ddp.py
│ │ │ └── thunder_fsdp.py
│ │ └── unsloth/
│ │ ├── __init__.py
│ │ ├── executor.py
│ │ └── kernels/
│ │ ├── __init__.py
│ │ ├── cross_entropy_loss.py
│ │ ├── rope_embedding.py
│ │ ├── swiglu.py
│ │ └── utils.py
│ └── xla/
│ ├── README.md
│ ├── __init__
│ ├── finetune/
│ │ ├── __init__
│ │ └── adapter.py
│ ├── generate/
│ │ ├── __init__
│ │ ├── adapter.py
│ │ └── base.py
│ ├── scripts/
│ │ ├── __init__
│ │ └── prepare_alpaca.py
│ └── utils.py
├── litgpt/
│ ├── __init__.py
│ ├── __main__.py
│ ├── adapter.py
│ ├── adapter_v2.py
│ ├── api.py
│ ├── args.py
│ ├── chat/
│ │ ├── __init__.py
│ │ └── base.py
│ ├── config.py
│ ├── constants.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── alpaca.py
│ │ ├── alpaca_2k.py
│ │ ├── alpaca_gpt4.py
│ │ ├── base.py
│ │ ├── deita.py
│ │ ├── flan.py
│ │ ├── json_data.py
│ │ ├── lima.py
│ │ ├── lit_data.py
│ │ ├── longform.py
│ │ ├── microllama.py
│ │ ├── openwebtext.py
│ │ ├── prepare_slimpajama.py
│ │ ├── prepare_starcoder.py
│ │ ├── text_files.py
│ │ ├── tinyllama.py
│ │ └── tinystories.py
│ ├── deploy/
│ │ ├── __init__.py
│ │ └── serve.py
│ ├── eval/
│ │ └── evaluate.py
│ ├── finetune/
│ │ ├── __init__.py
│ │ ├── adapter.py
│ │ ├── adapter_v2.py
│ │ ├── full.py
│ │ ├── lora.py
│ │ └── lora_legacy.py
│ ├── generate/
│ │ ├── __init__.py
│ │ ├── adapter.py
│ │ ├── adapter_v2.py
│ │ ├── base.py
│ │ ├── full.py
│ │ ├── sequentially.py
│ │ ├── speculative_decoding.py
│ │ └── tp.py
│ ├── lora.py
│ ├── model.py
│ ├── parser_config.py
│ ├── pretrain.py
│ ├── prompts.py
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── convert_hf_checkpoint.py
│ │ ├── convert_lit_checkpoint.py
│ │ ├── convert_pretrained_checkpoint.py
│ │ ├── download.py
│ │ └── merge_lora.py
│ ├── tokenizer.py
│ ├── types.py
│ └── utils.py
├── pyproject.toml
├── tests/
│ ├── conftest.py
│ ├── convert/
│ │ ├── __init__.py
│ │ ├── test_hf_checkpoint.py
│ │ ├── test_lit_checkpoint.py
│ │ └── test_pretrained_checkpoint.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── _fixtures/
│ │ │ ├── alpaca.json
│ │ │ ├── dolly.json
│ │ │ ├── longform_train.json
│ │ │ └── longform_val.json
│ │ ├── test_alpaca.py
│ │ ├── test_base.py
│ │ ├── test_deita.py
│ │ ├── test_json.py
│ │ ├── test_lit_data.py
│ │ ├── test_longform.py
│ │ ├── test_openwebtext.py
│ │ ├── test_textfiles.py
│ │ ├── test_tinyllama.py
│ │ └── test_tinystories.py
│ ├── ext_thunder/
│ │ ├── __init__.py
│ │ ├── test_thunder_distributed.py
│ │ ├── test_thunder_networks.py
│ │ ├── test_thunder_pretrain.py
│ │ └── test_unsloth_executor.py
│ ├── generate/
│ │ ├── __init__.py
│ │ ├── test_adapter.py
│ │ ├── test_main.py
│ │ ├── test_sequentially.py
│ │ ├── test_tp.py
│ │ └── utils.py
│ ├── test_adapter.py
│ ├── test_adapter_v2.py
│ ├── test_api.py
│ ├── test_args.py
│ ├── test_batch.py
│ ├── test_chat.py
│ ├── test_ci.py
│ ├── test_cli.py
│ ├── test_config.py
│ ├── test_config_hub.py
│ ├── test_deepseek_moe.py
│ ├── test_distributed.py
│ ├── test_evaluate.py
│ ├── test_full.py
│ ├── test_generate_speculatively.py
│ ├── test_lora.py
│ ├── test_merge_lora.py
│ ├── test_model.py
│ ├── test_multihead_latent_attention.py
│ ├── test_pretrain.py
│ ├── test_prompts.py
│ ├── test_readme.py
│ ├── test_rope.py
│ ├── test_serve.py
│ ├── test_tokenizer.py
│ ├── test_trainer_support.py
│ ├── test_types.py
│ ├── test_utils.py
│ └── test_yarn.py
└── tutorials/
├── 0_to_litgpt.md
├── convert_hf_checkpoint.md
├── convert_lit_models.md
├── deploy.md
├── developer-docs/
│ ├── README.md
│ ├── adding-models.md
│ └── python-api.md
├── download_model_weights.md
├── evaluation.md
├── examples/
│ └── ptl-trainer/
│ ├── README.md
│ ├── litgpt_ptl_medium.py
│ └── litgpt_ptl_small.py
├── finetune.md
├── finetune_adapter.md
├── finetune_full.md
├── finetune_lora.md
├── full_finetune_example.py
├── inference.md
├── mkdocs.yml
├── oom.md
├── prepare_dataset.md
├── pretrain.md
├── pretrain_tinyllama.md
├── python-api.md
├── quantize.md
└── resource-tables.md
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/Dockerfile
================================================
# See here for image contents: https://github.com/devcontainers/images/blob/main/src/python/.devcontainer/Dockerfile
# [Choice] Python version (use -bookworm or -bullseye variants on local arm64/Apple Silicon): 3, 3.12, 3.11, 3.10, 3.9, 3.8, 3-bookworm, 3.12-bookworm, 3.11-bookworm, 3.10-bookworm, 3.9-bookworm, 3.8-bookworm, 3-bullseye, 3.12-bullseye, 3.11-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3-buster, 3.12-buster, 3.11-buster, 3.10-buster, 3.9-buster, 3.8-buster
ARG VARIANT=3-bookworm
FROM mcr.microsoft.com/devcontainers/python:1-${VARIANT}
# Temporary: Upgrade python packages due to https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-40897
# They are installed by the base image (python) which does not have the patch.
RUN python3 -m pip install --upgrade pip setuptools
================================================
FILE: .devcontainer/devcontainer.json
================================================
// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
// https://github.com/microsoft/vscode-dev-containers/tree/v0.194.0/containers/python-3
{
"name": "Python 3 (litgpt)",
"build": {
"dockerfile": "Dockerfile",
"context": "..",
"args": {
"VARIANT": "3.11-bookworm"
}
},
"runArgs": [
// Enable GPU passthrough, requires WSL2 on Windows
//"--gpus=all",
// One of the following options is required for torch multiprocessing
//"--ipc=host",
//"--shm-size=4gb",
],
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
"ghcr.io/devcontainers/features/git:1": {},
"ghcr.io/devcontainers/features/git-lfs:1": {},
//"ghcr.io/devcontainers/features/nvidia-cuda:1": {},
"ghcr.io/devcontainers-extra/features/actionlint:1": {},
"ghcr.io/devcontainers-extra/features/pre-commit:2": {},
"ghcr.io/dhoeric/features/act:1": {},
"ghcr.io/devcontainers/features/docker-in-docker:2": {
"version": "latest",
"moby": true
}
},
// Set *default* container specific settings.json values on container create.
"customizations": {
"vscode": {
"settings": {
"editor.tabSize": 4,
"editor.renderWhitespace": "all",
"editor.formatOnSave": true,
"editor.rulers": [120],
"files.exclude": {
"**/__pycache__": true
},
"python.pythonPath": "/usr/local/bin/python",
"python.defaultInterpreterPath": "/usr/local/bin/python",
"python.languageServer": "Pylance",
"python.analysis.autoImportCompletions": true,
"python.analysis.completeFunctionParens": true,
"python.analysis.autoSearchPaths": true,
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"code-eol.highlightNonDefault": true,
"code-eol.highlightExtraWhitespace": true,
"autoDocstring.docstringFormat": "google-notypes",
"autoDocstring.guessTypes": true,
"autoDocstring.generateDocstringOnEnter": true,
"autoDocstring.startOnNewLine": true,
"telemetry.telemetryLevel": "off",
"[python]": {
"editor.formatOnSave": true,
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.codeActionsOnSave": {
"source.organizeImports": "always",
"source.fixAll": "always"
}
}
},
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"ms-toolsai.jupyter",
"GitHub.copilot",
"GitHub.copilot-chat",
"github.vscode-github-actions",
"SanjulaGanepola.github-local-actions",
"charliermarsh.ruff",
"esbenp.prettier-vscode",
"ms-vscode.test-adapter-converter",
"njqdev.vscode-python-typehint",
"KevinRose.vsc-python-indent",
"medo64.render-crlf",
"shardulm94.trailing-spaces",
"nhoizey.gremlins",
"wayou.vscode-todo-highlight",
"Gruntfuggly.todo-tree",
"njpwerner.autodocstring",
"rodolphebarbanneau.python-docstring-highlighter",
"mechatroner.rainbow-csv",
"uctakeoff.vscode-counter",
"bierner.github-markdown-preview",
"yahyabatulu.vscode-markdown-alert",
"ms-vscode-remote.vscode-remote-extensionpack",
"ms-azuretools.vscode-docker",
"redhat.vscode-yaml"
]
}
},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "pre-commit install && pip install '.[extra,compiler,test]' -U",
// Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
"remoteUser": "vscode"
}
================================================
FILE: .github/CODEOWNERS
================================================
* @lantiga @t-vi @lianakoleva @KaelanDt @k223kim @andyland
/README.md @williamfalcon @lantiga @lianakoleva
================================================
FILE: .github/ISSUE_TEMPLATE/ask-a-question.md
================================================
---
name: Ask a Question
about: Ask and answer questions related to LitGPT
title: ''
labels: question
---
Please describe your question here.
================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.yaml
================================================
name: Bug Report
description: Report errors related to LitGPT
title: "Description"
labels: bug
body:
- type: markdown
attributes:
value: |
Thank you for taking the time to report an issue. Please fill out the details below to help us resolve it.
- type: textarea
id: bug_description
attributes:
label: Bug description
description: A description of the issue.
placeholder: |
Please provide a description of what the bug or issue is.
validations:
required: true
- type: input
attributes:
label: Reproduced in studio
description: >
Create a new Lightning Studio with code that reproduces the issue and share the link.
Also include all the relevant files and data required to reproduce shared issue.
In case the code does not crash, please add assert statements to show what is the real and expected output.
A simple guide on how to create such a studio can be found [here](https://www.youtube.com/watch?v=YcW-2Zt_bFg&ab_channel=LightningAI).
placeholder: https://lightning.ai/...
validations:
required: false
- type: dropdown
id: operating_system
attributes:
label: What operating system are you using?
description: If applicable, please select the operating system where you experienced this issue.
options:
- "Unknown"
- "macOS"
- "Linux"
- "Windows"
validations:
required: true
- type: textarea
id: version
attributes:
label: LitGPT Version
description: |
Please provide details about your LitGPT version by running the following code in your terminal:
```
pip show litgpt | grep Version:
```
validations:
required: false
================================================
FILE: .github/ISSUE_TEMPLATE/feature-request.md
================================================
---
name: Suggest a Feature
about: Propose a new feature or enhancement
title: ''
labels: enhancement
---
Please describe the feature or enhancement along with the intended usecase.
================================================
FILE: .github/dependabot.yml
================================================
# Basic dependabot.yml file with
# minimum configuration for two package managers
version: 2
updates:
# Enable version updates for python
- package-ecosystem: "pip"
# Look for a `requirements` in the `root` directory
directory: "/"
# Check for updates once a week
schedule:
interval: "monthly"
# Labels on pull requests for version updates only
labels:
- "dependencies"
pull-request-branch-name:
# Separate sections of the branch name with a hyphen
# for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
separator: "-"
# Allow up to 5 open pull requests for pip dependencies
open-pull-requests-limit: 3
# Enable version updates for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
# Check for updates once a week
schedule:
interval: "weekly"
# Labels on pull requests for version updates only
labels:
- "CI / actions"
pull-request-branch-name:
# Separate sections of the branch name with a hyphen
# for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
separator: "-"
# Allow up to 5 open pull requests for GitHub Actions
open-pull-requests-limit: 1
groups:
GHA-updates:
patterns:
- "*"
================================================
FILE: .github/workflows/check-links.yml
================================================
name: Check hyperlinks
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install "mistune<3.1" # a newer version is incompatible with nbconvert
pip install pytest pytest-check-links
- name: Check links
run: |
pytest --check-links README.md --check-links-ignore "http*"
pytest --check-links tutorials --check-links-ignore "http*"
================================================
FILE: .github/workflows/cpu-tests.yml
================================================
name: CPU tests
on:
push:
branches: [main]
pull_request_target:
branches: [main]
types: [opened, reopened, ready_for_review, labeled, synchronize]
pull_request: {} # todo
workflow_dispatch: {}
# lock down all permissions by default
permissions:
contents: read # needed to check out code
checks: write # needed for test results
pull-requests: read # needed for PR metadata
actions: read # needed to use actions
security-events: none
statuses: write # needed to update commit status
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
env:
HF_HOME: .cache-HF # Define HF_HOME for caching
TRANSFORMERS_CACHE: .cache-HF/transformers
DATASETS_CACHE: .cache-HF/datasets
HF_DATASETS_CACHE: .cache-HF/datasets
TORCH_URL: "https://download.pytorch.org/whl/cpu/"
jobs:
testing-imports:
runs-on: ${{ matrix.os }}
if: github.event_name != 'pull_request_target'
strategy:
fail-fast: false
matrix:
os: ["ubuntu-22.04", "ubuntu-24.04", "macOS-14", "windows-2022"]
python-version: ["3.10"]
timeout-minutes: 10
steps:
- name: Checkout generic
uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install minimal dependencies
run: |
pip install . -U --extra-index-url="${TORCH_URL}"
pip list
- name: Testing package imports
# make sure all modules are still importable with only the minimal dependencies available
run: |
modules=$(
find litgpt -type f -name "*.py" | \
sed 's/\.py$//' | sed 's/\//./g' | \
sed 's/.__init__//g' | xargs -I {} echo "import {};"
)
echo "$modules"
python -c "$modules"
pytester:
# Route PRs based on contributor type to avoid duplicate runs:
# - Collaborators: use pull_request (tests workflow changes from PR)
# - External forks: use pull_request_target (uses trusted workflow from main)
# - Always run for push to main and workflow_dispatch
if: |
(github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||
(github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||
(github.event_name != 'pull_request' && github.event_name != 'pull_request_target')
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-22.04"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
requires: ["latest"]
include:
- { os: "ubuntu-22.04", python-version: "3.10", requires: "oldest" }
- { os: "windows-2022", python-version: "3.10", requires: "latest" }
- { os: "macOS-14", python-version: "3.10", requires: "latest" }
timeout-minutes: 35
steps:
- name: Checkout generic
uses: actions/checkout@v6
if: github.event_name != 'pull_request_target'
- name: Checkout for `pull_request_target`
uses: actions/checkout@v6
if: github.event_name == 'pull_request_target'
with:
ref: ${{ github.event.pull_request.head.sha }}
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
cache-dependency-path: pyproject.toml
cache: "pip"
# Add caching for HF models and tokenizers
- name: HF cache
uses: actions/cache@v5
continue-on-error: true
with:
path: .cache-HF
key: hf-cache_${{ runner.os }}-py${{ matrix.python-version }}
restore-keys: |
hf-cache_${{ runner.os }}-py${{ matrix.python-version }}
hf-cache_${{ runner.os }}-
hf-cache_
- name: Set min. dependencies
if: matrix.requires == 'oldest'
run: |
pip install 'lightning-utilities[cli]>=0.15.1'
python -m lightning_utilities.cli requirements set-oldest --req_files=pyproject.toml
- name: Install dependencies
run: |
pip install '.[extra,compiler,test]' -U --upgrade-strategy eager --extra-index-url="${TORCH_URL}"
pip list
- name: Run tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: pytest -v litgpt/ tests/ --timeout=180 --durations=100
- name: Show cache
run: |
pip install -q py-tree
python -m py_tree -d 1 .cache-HF
testing-guardian:
runs-on: ubuntu-latest
needs: [pytester, testing-imports]
if: |
(github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||
(github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association))
steps:
- run: echo "${{ needs.pytester.result }}"
- name: failing...
if: needs.pytester.result == 'failure'
run: exit 1
- name: cancelled or skipped...
if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result)
timeout-minutes: 1
run: sleep 90
================================================
FILE: .github/workflows/mkdocs-deploy.yml
================================================
name: Deploy MkDocs
on:
push:
branches: [main]
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-24.04
steps:
# Step 1: Checkout the repository
- uses: actions/checkout@v6
# Step 2: Set up Python
- uses: actions/setup-python@v6
with:
python-version: "3.x"
cache: "pip"
# Step 3: Install MkDocs and dependencies
- run: pip install mkdocs mkdocs-material mkdocs-pagetree-plugin
# Step 4: Deploy to GitHub Pages
- run: |
mkdir -p gh-pages/docs
cp -r tutorials/* gh-pages/docs
cd gh-pages
mv docs/mkdocs.yml mkdocs.yml
echo "{{ pagetree }}" > docs/index.md
mkdocs gh-deploy --force
================================================
FILE: .github/workflows/publish-pkg.yml
================================================
# To create a release, create a tag and push it to GitHub:
#git tag -a "v0.0.1-beta" -m "beta version testing"
#git push --tags
# https://dev.to/iamtekson/publish-package-to-pypi-and-release-new-version-using-github-actions-108k
name: Publish LitGPT to PyPI
on:
push:
tags:
- "v*"
jobs:
build-n-publish:
name: Build and publish to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/litgpt
permissions:
id-token: write
steps:
- name: Checkout source
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.x"
cache: "pip"
- name: Build source and wheel distributions
run: |
python -m pip install --upgrade build twine
pip install importlib_metadata==7.2.1
python -m build
twine check --strict dist/*
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
.ipynb_checkpoints/
__pycache__
.idea
.DS_Store
*.egg-info
build
dist
.venv
.venv/
.vscode
uv.lock
# data
data
datasets
!litgpt/data
!tests/data
checkpoints
out
wandb
events.out.tfevents*
# test artifacts from tests/test_readme.py
**/custom_finetuning_dataset.json
client.py
**/custom_texts/
================================================
FILE: .lightning/workflows/tests.yaml
================================================
trigger:
push:
branches: ["main"]
pull_request:
branches: ["main"]
image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.8.1-cudnn-fe1.15.0-py3.12-pt_2.8.0-dev"
machine: "L4_X_2"
interruptible: "true"
timeout: "45" # minutes
parametrize:
matrix:
dependency: ["", "compiler"]
include: []
exclude: []
env:
SKIP_WITH_CI: "1" # skip single tests with CI
NCCL_DEBUG: "INFO"
CUBLAS_WORKSPACE_CONFIG: ":4096:8"
NCCL_IGNORE_DISABLED_P2P: "1"
TORCH_VERSION: "2.8.0"
RUN_ONLY_CUDA_TESTS: "1" # run CUDA tests only
run: |
whereis nvidia
nvidia-smi
python --version
pip --version
pip list
set -ex
echo "Install uv and create virtual environment"
curl -LsSf https://astral.sh/uv/install.sh | sh
[ -f "$HOME/.local/bin/env" ] && . "$HOME/.local/bin/env"
export PATH="$HOME/.local/bin:$PATH"
uv venv .venv --system-site-packages
. .venv/bin/activate
hash -r
uv pip install -q '.[extra,test]' "torch==${TORCH_VERSION}" cffi -U
if [ "${dependency}" == "compiler" ]; then
uv pip uninstall torchvision torchaudio
uv pip install -q '.[compiler,extra,test]' "torch==${TORCH_VERSION}"
python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'"
python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'"
fi
uv pip list
python -c "import torch ; gpus = torch.cuda.device_count() ; assert gpus >= 2, f'GPU: {gpus}'"
python -c "from torch import __version__ as ver ; assert str(ver).split('+')[0] == '${TORCH_VERSION}', f'PyTorch: installed {ver} but expected ${TORCH_VERSION}'"
pytest -v --durations=100
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
PL_RUN_STANDALONE_TESTS=1 bash run_standalone_tests.sh "tests"
if [ "${dependency}" == "compiler" ]; then
uv pip uninstall lightning-thunder transformers
# install thunder from source, so that, thunder.tests will be available
uv pip install -U "lightning-thunder[test] @ git+https://github.com/Lightning-AI/lightning-thunder.git" "torch==${TORCH_VERSION}"
# Pin transformers to match thunder's test_networks.py requirements
# See: https://github.com/Lightning-AI/lightning-thunder/blob/main/requirements/test.txt
# Get transformers version from thunder requirements
TRANSFORMERS_VERSION=$(curl -fsSL https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/main/requirements/test.txt \
| grep '^transformers==' \
| cut -d'=' -f3 \
| cut -d'#' -f1 \
| xargs)
if [ -z "${TRANSFORMERS_VERSION}" ]; then
echo "Error: Could not determine transformers version from lightning-thunder requirements"
exit 1
fi
uv pip install transformers==${TRANSFORMERS_VERSION}
# without env var, it filters out all tests
RUN_ONLY_CUDA_TESTS=0 pytest tests/ext_thunder/test_thunder_networks.py -v
fi
================================================
FILE: .pre-commit-config.yaml
================================================
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
default_language_version:
python: python3
ci:
autofix_prs: true
autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
autoupdate_schedule: quarterly
# submodules: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: README.md
- id: check-yaml
- id: check-toml
#- id: check-docstring-first
#- id: check-executables-have-shebangs
- id: check-case-conflict
- id: check-added-large-files
args: ["--maxkb=250", "--enforce-all"]
- id: detect-private-key
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: [tomli]
args: ["--write-changes"]
exclude: pyproject.toml
#- repo: https://github.com/crate-ci/typos
# rev: dictgen-v0.3.1
# hooks:
# - id: typos
# args: [] # empty to do not write fixes
# exclude: pyproject.toml
#- repo: https://github.com/executablebooks/mdformat
# rev: 0.7.21
# hooks:
# - id: mdformat
# args: ["--number"]
# additional_dependencies:
# - mdformat-gfm
# - mdformat-black
# - mdformat_frontmatter
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
hooks:
- id: prettier
files: \.(json|yml|yaml|toml)
# https://prettier.io/docs/en/options.html#print-width
args: ["--print-width=140"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.10
hooks:
- id: ruff
args: ["--fix"]
- id: ruff-format
- id: ruff
- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.11.1
hooks:
- id: pyproject-fmt
additional_dependencies: [tox]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.24.1
hooks:
- id: validate-pyproject
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, you can cite it as shown below."
title: "LitGPT"
abstract: "20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale."
date-released: 2023-03-22
authors:
- name: "The Lightning AI team"
license: "Apache-2.0"
url: "https://github.com/Lightning-AI/litgpt"
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2023] Lightning AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# ⚡ LitGPT
**20+ high-performance LLMs with recipes to pretrain, finetune, and deploy at scale.**
✅ From scratch implementations ✅ No abstractions ✅ Beginner friendly
✅ Flash attention ✅ FSDP ✅ LoRA, QLoRA, Adapter
✅ Reduce GPU memory (fp4/8/16/32) ✅ 1-1000+ GPUs/TPUs ✅ 20+ LLMs
---

 [](https://github.com/Lightning-AI/lit-stablelm/blob/master/LICENSE) [](https://discord.gg/VptPCZkGNa)
Quick start •
Models •
Finetune •
Deploy •
All workflows •
Features •
Recipes (YAML) •
Lightning AI •
Tutorials
# Looking for GPUs?
Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) - purpose-built for PyTorch and PyTorch Lightning.
- [GPUs](https://lightning.ai/pricing?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) from $0.19.
- [Clusters](https://lightning.ai/clusters?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): frontier-grade training/inference clusters.
- [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you debug, tune and vibe train.
- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you optimize, and deploy models.
- [Notebooks](https://lightning.ai/notebooks?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Persistent GPU workspaces where AI helps you code and analyze.
- [Inference](https://lightning.ai/deploy?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Deploy models as inference APIs.
# Finetune, pretrain, and inference LLMs Lightning fast ⚡⚡
Every LLM is implemented from scratch with **no abstractions** and **full control**, making them blazing fast, minimal, and performant at enterprise scale.
✅ **Enterprise ready -** Apache 2.0 for unlimited enterprise use.
✅ **Developer friendly -** Easy debugging with no abstraction layers and single file implementations.
✅ **Optimized performance -** Models designed to maximize performance, reduce costs, and speed up training.
✅ **Proven recipes -** Highly-optimized training/finetuning recipes tested at enterprise scale.
# Quick start
Install LitGPT
```
pip install 'litgpt[extra]'
```
Load and use any of the [20+ LLMs](#choose-from-20-llms):
```python
from litgpt import LLM
llm = LLM.load("microsoft/phi-2")
text = llm.generate("Fix the spelling: Every fall, the family goes to the mountains.")
print(text)
# Corrected Sentence: Every fall, the family goes to the mountains.
```
✅ Optimized for fast inference
✅ Quantization
✅ Runs on low-memory GPUs
✅ No layers of internal abstractions
✅ Optimized for production scale
Advanced install options
Install from source:
```bash
git clone https://github.com/Lightning-AI/litgpt
cd litgpt
# if using uv
uv sync --all-extras
# if using pip
pip install -e ".[extra,compiler,test]"
```
[Explore the full Python API docs](tutorials/python-api.md).
---
# Choose from 20+ LLMs
Every model is written from scratch to maximize performance and remove layers of abstraction:
| Model | Model size | Author | Reference |
|----|----|----|----|
| Llama 3, 3.1, 3.2, 3.3 | 1B, 3B, 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) |
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
| ... | ... | ... | ... |
See full list of 20+ LLMs
#### All models
| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Gemma 2 | 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) |
| Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786) |
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) |
| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) |
| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743) |
| Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233) |
| Phi 4 Reasoning | 3.8B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) |
| Phi 4 Reasoning Plus | 3.8B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/) |
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
| QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
| Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
**Tip**: You can list all available models by running the `litgpt download list` command.
---
# Workflows
Finetune •
Pretrain •
Continued pretraining •
Evaluate •
Deploy •
Test
Use the command line interface to run advanced workflows such as pretraining or finetuning on your own data.
## All workflows
After installing LitGPT, select the model and workflow to run (finetune, pretrain, evaluate, deploy, etc...):
```bash
# litgpt [action] [model]
litgpt serve meta-llama/Llama-3.2-3B-Instruct
litgpt finetune meta-llama/Llama-3.2-3B-Instruct
litgpt pretrain meta-llama/Llama-3.2-3B-Instruct
litgpt chat meta-llama/Llama-3.2-3B-Instruct
litgpt evaluate meta-llama/Llama-3.2-3B-Instruct
```
----
## Finetune an LLM
Finetuning is the process of taking a pretrained AI model and further training it on a smaller, specialized dataset tailored to a specific task or application.
```bash
# 0) setup your dataset
curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json
# 1) Finetune a model (auto downloads weights)
litgpt finetune microsoft/phi-2 \
--data JSON \
--data.json_path my_custom_dataset.json \
--data.val_split_fraction 0.1 \
--out_dir out/custom-model
# 2) Test the model
litgpt chat out/custom-model/final
# 3) Deploy the model
litgpt serve out/custom-model/final
```
[Read the full finetuning docs](tutorials/finetune.md)
----
## Deploy an LLM
Deploy a pretrained or finetune LLM to use it in real-world applications. Deploy, automatically sets up a web server that can be accessed by a website or app.
```bash
# deploy an out-of-the-box LLM
litgpt serve microsoft/phi-2
# deploy your own trained model
litgpt serve path/to/microsoft/phi-2/checkpoint
```
Show code to query server:
Test the server in a separate terminal and integrate the model API into your AI product:
```python
# 3) Use the server (in a separate Python session)
import requests, json
response = requests.post(
"http://127.0.0.1:8000/predict",
json={"prompt": "Fix typos in the following sentence: Example input"}
)
print(response.json()["output"])
```
[Read the full deploy docs](tutorials/deploy.md).
----
## Evaluate an LLM
Evaluate an LLM to test its performance on various tasks to see how well it understands and generates text. Simply put, we can evaluate things like how well would it do in college-level chemistry, coding, etc... (MMLU, Truthful QA, etc...)
```bash
litgpt evaluate microsoft/phi-2 --tasks 'truthfulqa_mc2,mmlu'
```
[Read the full evaluation docs](tutorials/evaluation.md).
----
## Test an LLM
Test how well the model works via an interactive chat. Use the `chat` command to chat, extract embeddings, etc...
Here's an example showing how to use the Phi-2 LLM:
```bash
litgpt chat microsoft/phi-2
>> Prompt: What do Llamas eat?
```
Full code:
```bash
# 1) List all supported LLMs
litgpt download list
# 2) Use a model (auto downloads weights)
litgpt chat microsoft/phi-2
>> Prompt: What do Llamas eat?
```
The download of certain models requires an additional access token. You can read more about this in the [download](tutorials/download_model_weights.md#specific-models-and-access-tokens) documentation.
[Read the full chat docs](tutorials/inference.md).
----
## Pretrain an LLM
Pretraining is the process of teaching an AI model by exposing it to a large amount of data before it is fine-tuned for specific tasks.
Show code:
```bash
mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt
# 1) Download a tokenizer
litgpt download EleutherAI/pythia-160m \
--tokenizer_only True
# 2) Pretrain the model
litgpt pretrain EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--data TextFiles \
--data.train_data_path "custom_texts/" \
--train.max_tokens 10_000_000 \
--out_dir out/custom-model
# 3) Test the model
litgpt chat out/custom-model/final
```
[Read the full pretraining docs](tutorials/pretrain.md)
----
## Continue pretraining an LLM
Continued pretraining is another way of finetuning that specializes an already pretrained model by training on custom data:
Show code:
```bash
mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt
# 1) Continue pretraining a model (auto downloads weights)
litgpt pretrain EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--initial_checkpoint_dir EleutherAI/pythia-160m \
--data TextFiles \
--data.train_data_path "custom_texts/" \
--train.max_tokens 10_000_000 \
--out_dir out/custom-model
# 2) Test the model
litgpt chat out/custom-model/final
```
[Read the full continued pretraining docs](tutorials/pretrain.md#continued-pretraining-on-custom-data)
----
# State-of-the-art features
✅ State-of-the-art optimizations: Flash Attention v2, multi-GPU support via fully-sharded data parallelism, [optional CPU offloading](tutorials/oom.md#do-sharding-across-multiple-gpus), and [TPU and XLA support](extensions/xla).
✅ [Pretrain](tutorials/pretrain.md), [finetune](tutorials/finetune.md), and [deploy](tutorials/inference.md)
✅ Reduce compute requirements with low-precision settings: FP16, BF16, and FP16/FP32 mixed.
✅ Lower memory requirements with [quantization](tutorials/quantize.md): 4-bit floats, 8-bit integers, and double quantization.
✅ [Configuration files](config_hub) for great out-of-the-box performance.
✅ Parameter-efficient finetuning: [LoRA](tutorials/finetune_lora.md), [QLoRA](tutorials/finetune_lora.md), [Adapter](tutorials/finetune_adapter.md), and [Adapter v2](tutorials/finetune_adapter.md).
✅ [Exporting](tutorials/convert_lit_models.md) to other popular model weight formats.
✅ Many popular datasets for [pretraining](tutorials/pretrain.md) and [finetuning](tutorials/prepare_dataset.md), and [support for custom datasets](tutorials/prepare_dataset.md#preparing-custom-datasets-for-instruction-finetuning).
✅ Readable and easy-to-modify code to experiment with the latest research ideas.
---
# Training recipes
LitGPT comes with validated recipes (YAML configs) to train models under different conditions. We've generated these recipes based on the parameters we found to perform the best for different training conditions.
Browse all training recipes [here](config_hub).
### Example
```bash
litgpt finetune \
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml
```
✅ Use configs to customize training
Configs let you customize training for all granular parameters like:
```yaml
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama2-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
...
```
✅ Example: LoRA finetuning config
```yaml
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama2-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: float, default: 0.0003)
learning_rate: 0.0002
# (type: float, default: 0.02)
weight_decay: 0.0
# (type: float, default: 0.9)
beta1: 0.9
# (type: float, default: 0.95)
beta2: 0.95
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
```
✅ Override any parameter in the CLI:
```bash
litgpt finetune \
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \
--lora_r 4
```
----
# Project highlights
LitGPT powers many great AI projects, initiatives, challenges and of course enterprises. Please submit a pull request to be considered for a feature.
📊 SAMBA: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling
The [Samba](https://github.com/microsoft/Samba) project by researchers at Microsoft is built on top of the LitGPT code base and combines state space models with sliding window attention, which outperforms pure state space models.
🏆 NeurIPS 2023 Large Language Model Efficiency Challenge: 1 LLM + 1 GPU + 1 Day
The LitGPT repository was the official starter kit for the [NeurIPS 2023 LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io), which is a competition focused on finetuning an existing non-instruction tuned LLM for 24 hours on a single GPU.
🦙 TinyLlama: An Open-Source Small Language Model
LitGPT powered the [TinyLlama project](https://github.com/jzhang38/TinyLlama) and [TinyLlama: An Open-Source Small Language Model](https://arxiv.org/abs/2401.02385) research paper.
🍪 MicroLlama: MicroLlama-300M
[MicroLlama](https://github.com/keeeeenw/MicroLlama) is a 300M Llama model pretrained on 50B tokens powered by TinyLlama and LitGPT.
🔬 Pre-training Small Base LMs with Fewer Tokens
The research paper ["Pre-training Small Base LMs with Fewer Tokens"](https://arxiv.org/abs/2404.08634), which utilizes LitGPT, develops smaller base language models by inheriting a few transformer blocks from larger models and training on a tiny fraction of the data used by the larger models. It demonstrates that these smaller models can perform comparably to larger models despite using significantly less training data and resources.
----
# Community
We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.
- [Request a feature](https://github.com/Lightning-AI/litgpt/issues)
- [Submit your first contribution](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/)
- [Join our Discord](https://discord.gg/VptPCZkGNa)
# Tutorials
🚀 [Get started](tutorials/0_to_litgpt.md)
⚡️ [Finetuning, incl. LoRA, QLoRA, and Adapters](tutorials/finetune.md)
🤖 [Pretraining](tutorials/pretrain.md)
💬 [Model evaluation](tutorials/evaluation.md)
📘 [Supported and custom datasets](tutorials/prepare_dataset.md)
🧹 [Quantization](tutorials/quantize.md)
🤯 [Tips for dealing with out-of-memory (OOM) errors](tutorials/oom.md)
🧑🏽💻 [Using cloud TPUs](extensions/xla)
----
### Acknowledgments
This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) ⚡**.
- [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)
- [@EleutherAI](https://github.com/EleutherAI) for [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) and the [Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
- [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)
- [@tridao](https://github.com/tridao) for [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
### License
LitGPT is released under the [Apache 2.0](https://github.com/Lightning-AI/litgpt/blob/main/LICENSE) license.
### Citation
If you use LitGPT in your research, please cite the following work:
```bibtex
@misc{litgpt-2023,
author = {Lightning AI},
title = {LitGPT},
howpublished = {\url{https://github.com/Lightning-AI/litgpt}},
year = {2023},
}
```
================================================
FILE: config_hub/finetune/README.md
================================================
## Config files
The table below lists the performances you can expect from the provided config files. Note that you can achieve lower memory consumption by lowering the micro batch size as needed. In addition, you can lower the rank (`lora_r`) in the LoRA configuration files and disable LoRA for certain layers (for example, setting `lora_projection` and other LoRA layer-specific parameters to `false`).
For more information, see the [Dealing with out-of-memory (OOM) errors](../../tutorials/oom.md) on lowering the memory requirements.
The "Cost" column refers to the on-demand compute cost on [Lightning AI Studios where these benchmarks were executed](https://lightning.ai/lightning-ai/studios/automated-benchmarks-for-litgpt).
All experiments were conducted using bfloat-16 precision on the Alpaca2k dataset. The "Multitask score" refers to [MMLU](https://arxiv.org/abs/2009.03300).
| Config | Model | Epochs | Max seq length | Micro batch size | Machine | Training runtime | Cost | Peak memory | Validation loss | Validation perplexity | Multitask score (MMLU) |
| --------------------------------- | ---------------------- | ------ | -------------- | ---------------- | ------- | ---------------- | ---- | ----------- | --------------- | --------------------- | --------------- |
| falcon-7b/lora.yaml | falcon-7b | 4 | 512 | 1 | 1xA10G | 24.84 min | $0.7 | 16.69 GB | 0.945 | 2.573 | 26.2% |
| falcon-7b/lora.yaml | falcon-7b | 4 | 512 | 1 | 4xA10G | 24.94 min | $2.0 | 16.69 GB | 0.945 | 2.573 | 26.4% |
| falcon-7b/qlora.yaml | falcon-7b | 4 | 512 | 1 | 1xA10G | 50.85 min | $1.5 | 9.44 GB | 0.993 | 2.699 | 26.3% |
| | | | | | | | | | | | |
| gemma-2b/full.yaml | gemma-2b | 1 | 512 | 1 | 4xA10G | 14.06 min | $1.1 | 17.43 GB | 1.021 | 2.777 | 32.4% |
| gemma-2b/lora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 9.41 min | $0.3 | 12.62 GB | 0.981 | 2.666 | 34.4% |
| gemma-2b/lora.yaml | gemma-2b | 2 | 512 | 2 | 4xA10G | 9.41 min | $0.8 | 12.62 GB | 0.981 | 2.667 | 34.0% |
| gemma-2b/qlora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 12.91 min | $0.4 | 11.58 GB | 1.085 | 2.959 | 36.4% |
| | | | | | | | | | | | |
| gemma-7b/lora.yaml | gemma-7b | 2 | 512 | 1 | 1xA10G | OOM | OOM | OOM | OOM | OOM | |
| gemma-7b/lora.yaml | gemma-7b | 2 | 512 | 1 | 4xA10G | OOM | OOM | OOM | OOM | OOM | |
| gemma-7b/qlora.yaml | gemma-7b | 2 | 512 | 1 | 1xA10G | 43.58 min | $1.3 | 17.18 GB | 0.973 | 2.646 | 62.45% |
| | | | | | | | | | | | |
| gemma2-2b/lora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 11.96 min | $0.4 | 14.31 GB | 0.951 | 2.589 | 23.84% |
| gemma2b/qlora.yaml | gemma-2b | 2 | 512 | 2 | 1xA10G | 16.06 min | $0.5 | 13.52 GB | 0.983 | 2.673 | 24.12% |
| | | | | | | | | | | | |
| gemma2-9b/lora.yaml | gemma-2-9b | 2 | 512 | 1 | 1xA10G | OOM | OOM | OOM | OOM | OOM | |
| gemma2-9b/lora.yaml | gemma-2-9b | 2 | 512 | 1 | 4xA10G | OOM | OOM | OOM | OOM | OOM | |
| gemma2-9b/qlora.yaml | gemma-2-9b | 2 | 512 | 1 | 1xA10G | 50.01 min | $4.0 | 20.92 GB | 0.852 | 2.345 | 24.2% |
| | | | | | | | | | | | |
| llama-2-7b/full.yaml | llama-2-7b | 1 | 512 | 4 | 4xA10G | OOM | OOM | OOM | OOM | OOM | |
| llama-2-7b/lora.yaml | llama-2-7b | 4 | 512 | 2 | 1xA10G | 32.82 min | $1.0 | 19.77 GB | 0.802 | 2.230 | 40.3% |
| llama-2-7b/lora.yaml | llama-2-7b | 4 | 512 | 2 | 4xA10G | 32.83 min | $2.6 | 19.77 GB | 0.802 | 2.229 | 40.2% |
| llama-2-7b/qlora.yaml | llama-2-7b | 4 | 512 | 2 | 1xA10G | 45.67 min | $1.4 | 13.68 GB | 0.814 | 2.258 | 38.6% |
| | | | | | | | | | | | |
| llama-3-8b/full.yaml | llama-3-8b | 1 | 512 | 4 | 4xA10G | OOM | OOM | OOM | OOM | OOM | |
| llama-3-8b/lora.yaml | llama-3-8b | 2 | 512 | 1 | 1xA10G | 14.79 min | $0.4 | 19.73 GB | 0.888 | 2.431 | 62.4% |
| llama-3-8b/lora.yaml | llama-3-8b | 2 | 512 | 1 | 4xA10G | 14.88 min | $1.2 | 19.73 GB | 0.889 | 2.432 | 62.5% |
| llama-3-8b/qlora.yaml | llama-3-8b | 2 | 512 | 2 | 1xA10G | 22.24 min | $0.7 | 17.41 GB | 0.939 | 2.558 | 62.2% |
| | | | | | | | | | | | |
| llama-3.1-8b/full.yaml | llama-3.1-8b | 1 | 512 | 4 | 1xA10G | OOM | OOM | OOM | OOM | OOM | OOM |
| llama-3.1-8b/lora.yaml | llama-3.1-8b | 2 | 512 | 1 | 1xA10G | 13.36 min | $1.1 | 19.73 GB | 0.878 | 2.406 | xx.xx |
| llama-3.1-8b/qlora.yaml | llama-3.1-8b | 2 | 512 | 2 | 1xA10G | 21.81 min | $0.7 | 17.41 GB | 0.928 | 2.529 | xx.xx |
| | | | | | | | | | | | |
| llama-3.2-1b/full.yaml | llama-3.2-1b | 1 | 512 | 4 | 1xA10G | 2.01 min | $0.1 | 8.70 GB | 1.442 | 4.229 | 38.21% |
| llama-3.2-1b/lora.yaml | llama-3.2-1b | 2 | 512 | 1 | 1xA10G | 4.17 min | $0.4 | 4.49 GB | 1.114 | 3.046 | 36.87% |
| llama-3.2-1b/qlora.yaml | llama-3.2-1b | 2 | 512 | 2 | 1xA10G | 6.20 min | $0.6 | 5.53 GB | 1.201 | 3.322 | 36.49% |
| | | | | | | | | | | | |
| llama-3.2-3b/full.yaml | llama-3.2-3b | 1 | 512 | 4 | 1xA10G | 4.71 min | $0.4 | 16.51 GB | 1.255 | 3.509 | 54.69% |
| llama-3.2-3b/lora.yaml | llama-3.2-3b | 2 | 512 | 1 | 1xA10G | 8.31 min | $0.8 | 9.67 GB | 0.973 | 2.647 | 54.77% |
| llama-3.2-3b/qlora.yaml | llama-3.2-3b | 2 | 512 | 2 | 1xA10G | 14.89 min | $1.4 | 10.30 GB | 1.031 | 2.804 | 55.08% |
| | | | | | | | | | | | |
| mistral-7b-v0.2/lora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 1xA10G | 31.00 min | $0.9 | 20.66 GB | 0.801 | 2.228 | 55.7% |
| mistral-7b-v0.2/lora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 4xA10G | 31.00 min | $2.5 | 20.66 GB | 0.802 | 2.229 | 55.5% |
| mistral-7b-v0.2/qlora.yaml | mistral-7b-v0.2 | 4 | 512 | 2 | 1xA10G | 44.75 min | $1.3 | 14.29 GB | 0.813 | 2.255 | 56.5% |
| | | | | | | | | | | | |
| mistral-7b/lora.yaml | mistral-7b | 4 | 512 | 2 | 1xA10G | 31.01 min | $0.9 | 20.66 GB | 0.794 | 2.211 | 57.9% |
| mistral-7b/lora.yaml | mistral-7b | 4 | 512 | 2 | 4xA10G | 31.03 min | $2.5 | 20.66 GB | 0.796 | 2.218 | 57.9% |
| mistral-7b/qlora.yaml | mistral-7b | 4 | 512 | 2 | 1xA10G | 44.75 min | $1.3 | 14.29 GB | 0.803 | 2.231 | 57.9% |
| | | | | | | | | | | | |
| phi-2/full.yaml | phi-2 | 1 | 512 | 4 | 4xA10G | 11.87 min | $1.0 | 14.44 GB | 1.305 | 3.688 | 38.4% |
| phi-2/lora.yaml | phi-2 | 1 | 512 | 4 | 1xA10G | 3.78 min | $0.1 | 13.98 GB | 0.819 | 2.269 | 53.0% |
| phi-2/lora.yaml | phi-2 | 1 | 512 | 4 | 4xA10G | 3.78 min | $0.3 | 13.98 GB | 0.820 | 2.271 | 52.4% |
| phi-2/qlora.yaml | phi-2 | 1 | 512 | 4 | 1xA10G | 4.51 min | $0.1 | 14.27 GB | 0.837 | 2.310 | 52.3% |
| | | | | | | | | | | | |
| phi-3/full.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 6.93 min | $0.2 | 17.01 GB | 0.714 | 2.043 | 69.81% |
| phi-3/lora.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 6.46 min | $0.2 | 19.75 GB | 0.707 | 2.028 | 69.70% |
| phi-3/qlora.yaml | Phi-3-mini-4k-instruct | 1 | 512 | 4 | 1xA10G | 7.47 min | $0.2 | 19.13 GB | 0.729 | 2.074 | 68.96% |
| | | | | | | | | | | | |
| stablelm-base-alpha-3b/full.yaml | stablelm-base-alpha-3b | 1 | 512 | 1 | 4xA10G | 70.13 min | $5.6 | 21.23 GB | 1.513 | 4.540 | 23.2% |
| stablelm-base-alpha-3b/lora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 1xA10G | 13.07 min | $0.4 | 8.58 GB | 1.361 | 3.900 | 25.9% |
| stablelm-base-alpha-3b/lora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 4xA10G | 13.16 min | $1.1 | 8.58 GB | 1.362 | 3.906 | 25.9% |
| stablelm-base-alpha-3b/qlora.yaml | stablelm-base-alpha-3b | 4 | 512 | 1 | 1xA10G | 25.86 min | $0.8 | 5.24 GB | 1.388 | 4.009 | 26.1% |
| | | | | | | | | | | | |
| tiny-llama/full.yaml | tiny-llama | 1 | 512 | 4 | 1xA10G | 2.58 min | $0.1 | 14.10 GB | 1.088 | 2.968 | 24.6% |
| tiny-llama/full.yaml | tiny-llama | 1 | 512 | 4 | 4xA10G | 2.57 min | $0.2 | 14.10 GB | 1.088 | 2.968 | 24.5% |
| tiny-llama/lora.yaml | tiny-llama | 3 | 512 | 8 | 1xA10G | 8.09 min | $0.2 | 13.50 GB | 1.039 | 2.826 | 25.5% |
| tiny-llama/qlora.yaml | tiny-llama | 3 | 512 | 8 | 1xA10G | 8.70 min | $0.3 | 16.24 GB | 1.056 | 2.874 | 25.3% |
*OOM = Out of memory
## Extending the context length
If you require a longer sequence length than the one used in a given config file, you can either edit the `max_seq_length` in the config file or pass an additional argument when running the finetuning command, for example, `--max_seq_length 4096` to override the sequence length provided in the config file.
## Training on GPUs without bfloat16 support
If you are training on GPUs without bfloat-16 support, you need to change the `precision` option to `16-true` (16-bit floating point precision) or `16-mixed` (16/32-bit mixed precision) training:
```bash
litgpt finetune lora \
--config config_hub/finetune/phi-2/lora.yaml \
--precision 16-true
```
or
```bash
litgpt finetune lora \
--config config_hub/finetune/phi-2/lora.yaml \
--precision 16-mixed
```
Note that `16-true` is more compute and memory-efficient, but it can sometimes lead to training convergence issues. In this case, it's recommended to use `16-mixed`.
## Multi-GPU experiments
All runs are single-GPU experiments, use `--devices 4` to utilize more than one GPU:
```bash
litgpt finetune lora \
--config config_hub/finetune/phi-2/lora.yaml \
--devices 4
```
================================================
FILE: config_hub/finetune/falcon-7b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/tiiuae/falcon-7b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-falcon-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/falcon-7b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/tiiuae/falcon-7b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-falcon-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma-2b/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/full-gemma-2b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 4
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 16
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 100
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps: 50
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma-2b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-gemma-2b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma-2b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-gemma-2b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma-7b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-7b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-gemma-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma-7b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-7b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-gemma-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma2-2b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2-2b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-gemma-2-2b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma2-2b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2-2b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-gemma-2-2b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma2-9b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2-9b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-gemma-2-9b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/gemma2-9b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/google/gemma-2-9b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-gemma-2-9b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 16
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.1
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 6
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-2-7b/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-llama2-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 4
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 64
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 25
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-2-7b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-llama2-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-2-7b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama2-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3-8b/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-llama-3-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 4
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 64
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 25
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3-8b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-llama-3-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3-8b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama3-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.1-8b/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-llama-3.1-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 4
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 64
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 25
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.1-8b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-llama-3.1-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.1-8b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama3.1-8b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-1B/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-llama-3.2-1B
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
# resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 64
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 25
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-1B/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-llama-3.2-1B
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-1B/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-1B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama3.2-1b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-3B/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-llama-3.2-3B
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
# resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 64
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 25
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-3B/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-llama-3.2-3B
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/llama-3.2-3B/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/meta-llama/Llama-3.2-3B
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-llama3.2-3b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 2
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/mistral-7b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-mistral-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/mistral-7b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-mistral-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/mistral-7b-v0.2/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-mistral-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/mistral-7b-v0.2/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-mistral-7b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 2
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-2/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/phi-2
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-phi-2
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 2
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps: 100
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-2/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/phi-2
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-phi-2
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-2/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/phi-2
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-phi-2
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-3/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct
# Directory in which to save checkpoints and logs. (type: , default: out/finetune/full)
out_dir: out/finetune/full-phi-3
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use (type: Union[int, str], default: 1)
devices: 1
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 1)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 200
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 600)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-3/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-phi-3
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/phi-3/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-phi-3
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 8
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/stablelm-base-alpha-3b/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/full-stablelm-base-alpha-3b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 2
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 1000
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/stablelm-base-alpha-3b/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-stablelm-base-alpha-3b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/stablelm-base-alpha-3b/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-stablelm-base-alpha-3b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: false
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: false
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: false
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: false
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.05
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
download_dir: data/alpaca2k
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 200
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 1
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 4
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run (type: Optional[int], default: null)
max_steps:
# Limits the length of samples (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/tiny-llama/full.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/full-tiny-llama-1.1b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 32
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 1000
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 1
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 25
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/tiny-llama/lora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/lora-tiny-llama-1.1b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize:
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 8
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 3
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/finetune/tiny-llama/qlora.yaml
================================================
# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b)
checkpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
# Directory in which to save checkpoints and logs. (type: , default: out/lora)
out_dir: out/finetune/qlora-tiny-llama-1.1b
# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-true
# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)
quantize: bnb.nf4
# How many devices/GPUs to use. (type: Union[int, str], default: 1)
devices: 1
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# The LoRA rank. (type: int, default: 8)
lora_r: 32
# The LoRA alpha. (type: int, default: 16)
lora_alpha: 16
# The LoRA dropout value. (type: float, default: 0.05)
lora_dropout: 0.05
# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)
lora_query: true
# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)
lora_key: true
# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)
lora_value: true
# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)
lora_projection: true
# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)
lora_mlp: true
# Whether to apply LoRA to output head in GPT. (type: bool, default: False)
lora_head: true
# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
data:
class_path: litgpt.data.Alpaca2k
init_args:
mask_prompt: false
val_split_fraction: 0.03847
prompt_style: alpaca
ignore_index: -100
seed: 42
num_workers: 4
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 800
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)
global_batch_size: 8
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 8
# Number of iterations with learning rate warmup active (type: int, default: 100)
lr_warmup_steps: 10
# Number of epochs to train on (type: Optional[int], default: 5)
epochs: 3
# Total number of tokens to train on (type: Optional[int], default: null)
max_tokens:
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 512
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)
tie_embeddings:
# (type: Optional[float], default: null)
max_norm:
# (type: float, default: 6e-05)
min_lr: 6.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 100)
interval: 100
# Number of tokens to generate (type: Optional[int], default: 100)
max_new_tokens: 100
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: true
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 1337)
seed: 1337
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0002
# (type: float, default: 0.01)
weight_decay: 0.0
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
================================================
FILE: config_hub/pretrain/debug.yaml
================================================
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: pythia-14m
# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_config``. (type: Optional[Config], default: null)
model_config:
# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs//share. (type: , default: out/pretrain)
out_dir: out/pretrain/debug
# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed
# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: TinyStories
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
global_batch_size: 125
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 5
# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 100
# Number of epochs to train on (type: Optional[int], default: null)
epochs:
# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 100000000
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length:
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:
# (type: Optional[float], default: 1.0)
max_norm: 1.0
# (type: float, default: 4e-05)
min_lr: 6e-5
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000
# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: false
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 6e-4
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# module require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/EleutherAI/pythia-14m
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)
logger_name: tensorboard
# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
================================================
FILE: config_hub/pretrain/microllama.yaml
================================================
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M
# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_config``. (type: Optional[Config], default: null)
model_config:
# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs//share. (type: , default: out/pretrain)
out_dir: out/pretrain/micro-llama
# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed
# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
# Scale this number according to the number of GPU and memory size per GPU
# For example, we used 48 for 4 x 24G 4090
global_batch_size: 48
# Number of samples per data-parallel rank (type: int, default: 12)
# Scale this number according to the memory size per GPU
# For example, we used 12 for 24G 4090
micro_batch_size: 12
# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 2000
# Number of epochs to train on (type: Optional[int], default: null)
epochs:
# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 3000000000000
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 2048
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:
# (type: Optional[float], default: 1.0)
max_norm: 1.0
# (type: float, default: 4e-05)
min_lr: 4.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000
# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 4e-4
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# module require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)
logger_name: tensorboard
# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
================================================
FILE: config_hub/pretrain/tinyllama.yaml
================================================
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: tiny-llama-1.1b
# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_config``. (type: Optional[Config], default: null)
model_config:
# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs//share. (type: , default: out/pretrain)
out_dir: out/pretrain/tiny-llama
# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed
# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: TinyLlama
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
global_batch_size: 512
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 4
# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 2000
# Number of epochs to train on (type: Optional[int], default: null)
epochs:
# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 3000000000000
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 2048
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:
# (type: Optional[float], default: 1.0)
max_norm: 1.0
# (type: float, default: 4e-05)
min_lr: 4.0e-05
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000
# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: false
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 4e-4
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# module require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)
logger_name: tensorboard
# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
================================================
FILE: config_hub/pretrain/tinystories.yaml
================================================
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: stories15M
# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_config``. (type: Optional[Config], default: null)
model_config:
name: stories15M
hf_config: {}
scale_embeddings: false
block_size: 256
padded_vocab_size: 32000
n_layer: 6
n_head: 6
n_query_groups: 6
n_embd: 288
head_size: 48
rotary_percentage: 1.0
parallel_residual: false
bias: false
norm_class_name: RMSNorm
mlp_class_name: LLaMAMLP
intermediate_size: 768
# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs//share. (type: , default: out/pretrain)
out_dir: out/pretrain/stories15M
# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed
# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:
# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
# (type: Union[bool, Literal["auto"], Path], default: False)
resume: false
# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: TinyStories
# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000
# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1
# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
global_batch_size: 512
# Number of samples per data-parallel rank (type: int, default: 4)
micro_batch_size: 128
# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 1000
# Number of epochs to train on (type: Optional[int], default: null)
epochs:
# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 9700000000 # original did 298,000 iters
# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:
# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 256
# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings: true
# (type: Optional[float], default: 1.0)
max_norm: 1.0
# (type: float, default: 4e-05)
min_lr: 0.0
# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 2000
# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:
# Number of iterations (type: int, default: 100)
max_iters: 100
# Whether to evaluate on the validation set at the beginning of the training
initial_validation: false
# Whether to evaluate on the validation set at the end the training
final_validation: false
# Optimizer-related arguments
optimizer:
class_path: torch.optim.AdamW
init_args:
# (type: float, default: 0.001)
lr: 0.0005
# (type: float, default: 0.01)
weight_decay: 0.1
# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95
# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto
# How many nodes to use. (type: int, default: 1)
num_nodes: 1
# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# module require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)
logger_name: csv
# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
================================================
FILE: extensions/thunder/README.md
================================================
# Lightning Thunder: a source-to-source compiler for PyTorch
[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) makes PyTorch programs faster both on single accelerators or in distributed settings.
Thunder aims to be usable, understandable, and extensible and can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors.
This extension directory shows how Thunder can be used with LitGPT.
> [!WARNING]
> This document is an early-access development version that is currently only for internal use. We recommend users checking out the [Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) project directly, which provides more up-to-date usage information.
## Thunder 👉👈 LitGPT: a short showcase
To try Lightning Thunder with your model simply `thunder.jit()` it.
```python
from litgpt import GPT
import thunder
import torch
# Use only two layers to keep the traces shorter for the demonstration
model = GPT.from_name("Llama-2-7b-hf", n_layer=2).cuda()
model = thunder.jit(model)
x = torch.randint(model.max_seq_length, (2, 5), device="cuda")
y = model(x) # forward, this may take a bit
```
This will require some compilation time on the first forward call.
### Traces
The JIT is will acquire a Python program (what we call a "trace") from the Python program (`GPT`, a `torch.nn.Module` in this example) that was given.
This process targets PyTorch operators (like `Tensor.view()`, `+`, `torch.nn.functional.scaled_dot_product_atttention()`) and optionally custom operators (more about that later).
We can visualize the thunder trace generated under the hood:
```python
forward_trace = thunder.last_traces(model)[-1].python()
print(forward_trace)
```
```python
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
# args: "Collection"
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \
t18, t19, = args
del args
t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False) # t24: "cuda:0 f32[2, 5, 4096]"
t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1]) # t20: "cuda:0 f32[5, 128]"
t21 = torch_slice_prim_impl(t2, [0, 0], [5, 128], [1, 1]) # t21: "cuda:0 f32[5, 128]"
t200 = torch.unsqueeze(t11, 0) # t200: "cuda:0 f32[1, 4096]"
t201 = torch.unsqueeze(t200, 1) # t201: "cuda:0 f32[1, 1, 4096]"
del t200
t33 = Tensor.expand(t201, (2, 5, 4096)) # t33: "cuda:0 f32[2, 5, 4096]"
del t201
t229 = torch.unsqueeze(t13, 0) # t229: "cuda:0 f32[1, 4096]"
t230 = torch.unsqueeze(t229, 1) # t230: "cuda:0 f32[1, 1, 4096]"
del t229
t84 = Tensor.expand(t230, (2, 5, 4096)) # t84: "cuda:0 f32[2, 5, 4096]"
del t230
t232 = torch.unsqueeze(t12, 0) # t232: "cuda:0 f32[1, 4096]"
t233 = torch.unsqueeze(t232, 1) # t233: "cuda:0 f32[1, 1, 4096]"
del t232
t104 = Tensor.expand(t233, (2, 5, 4096)) # t104: "cuda:0 f32[2, 5, 4096]"
del t233
t253 = torch.unsqueeze(t14, 0) # t253: "cuda:0 f32[1, 4096]"
t254 = torch.unsqueeze(t253, 1) # t254: "cuda:0 f32[1, 1, 4096]"
del t253
t155 = Tensor.expand(t254, (2, 5, 4096)) # t155: "cuda:0 f32[2, 5, 4096]"
del t254
t256 = torch.unsqueeze(t10, 0) # t256: "cuda:0 f32[1, 4096]"
t257 = torch.unsqueeze(t256, 1) # t257: "cuda:0 f32[1, 1, 4096]"
del t256
t175 = Tensor.expand(t257, (2, 5, 4096)) # t175: "cuda:0 f32[2, 5, 4096]"
del t257
t221 = torch.unsqueeze(t20, 0) # t221: "cuda:0 f32[1, 5, 128]"
del t20
t222 = torch.unsqueeze(t221, 1) # t222: "cuda:0 f32[1, 1, 5, 128]"
del t221
t49 = Tensor.expand(t222, (2, 32, 5, 128)) # t49: "cuda:0 f32[2, 32, 5, 128]"
del t222
t224 = torch.unsqueeze(t21, 0) # t224: "cuda:0 f32[1, 5, 128]"
del t21
t225 = torch.unsqueeze(t224, 1) # t225: "cuda:0 f32[1, 1, 5, 128]"
del t224
t51 = Tensor.expand(t225, (2, 32, 5, 128)) # t51: "cuda:0 f32[2, 32, 5, 128]"
del t225
[t30, t34] = nvFusion0(t24, t33)
t35 = torch.nn.functional.linear(t34, t3, None) # t35: "cuda:0 f32[2, 5, 12288]"
t36 = torch.reshape(t35, (2, 5, 32, 3, 128)) # t36: "cuda:0 f32[2, 5, 32, 3, 128]"
del t35
t37 = torch.permute(t36, (0, 2, 3, 1, 4)) # t37: "cuda:0 f32[2, 32, 3, 5, 128]"
del t36
(t38, t39, t40) = torch.split(t37, (1, 1, 1), 2)
del t37
t41 = torch.reshape(t38, (2, 32, 5, 128)) # t41: "cuda:0 f32[2, 32, 5, 128]"
del t38
t42 = torch.reshape(t39, (2, 32, 5, 128)) # t42: "cuda:0 f32[2, 32, 5, 128]"
del t39
t43 = torch.reshape(t40, (2, 32, 5, 128)) # t43: "cuda:0 f32[2, 32, 5, 128]"
del t40
t44 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t44: "cuda:0 f32[2, 32, 5, 128]"
t54 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t54: "cuda:0 f32[2, 32, 5, 128]"
t64 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t64: "cuda:0 f32[2, 32, 5, 0]"
del t41
t66 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t66: "cuda:0 f32[2, 32, 5, 0]"
del t42
t46 = torch_slice_prim_impl(t44, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t46: "cuda:0 f32[2, 32, 5, 64]"
t45 = torch_slice_prim_impl(t44, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t45: "cuda:0 f32[2, 32, 5, 64]"
t55 = torch_slice_prim_impl(t54, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t55: "cuda:0 f32[2, 32, 5, 64]"
t56 = torch_slice_prim_impl(t54, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t56: "cuda:0 f32[2, 32, 5, 64]"
[t47, t57] = nvFusion1(t46, t56)
del t46, t56
t48 = torch.cat((t47, t45), -1) # t48: "cuda:0 f32[2, 32, 5, 128]"
del t47, t45
t58 = torch.cat((t57, t55), -1) # t58: "cuda:0 f32[2, 32, 5, 128]"
del t57, t55
[t53, t63] = nvFusion2(t44, t48, t49, t51, t54, t58)
del t44, t48, t54, t58
t65 = torch.cat((t53, t64), -1) # t65: "cuda:0 f32[2, 32, 5, 128]"
del t53, t64
t67 = torch.cat((t63, t66), -1) # t67: "cuda:0 f32[2, 32, 5, 128]"
del t63, t66
(t68, t69, t70, t71) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t65, t67, t43, None, 0.0, True, 0.08838834764831843)
t72 = torch.permute(t68, (0, 2, 1, 3)) # t72: "cuda:0 f32[2, 5, 32, 128]"
t73 = torch.reshape(t72, (2, 5, 4096)) # t73: "cuda:0 f32[2, 5, 4096]"
del t72
t74 = torch.nn.functional.linear(t73, t15, None) # t74: "cuda:0 f32[2, 5, 4096]"
[t75, t81, t85] = nvFusion3(t24, t74, t84)
del t74
t86 = torch.nn.functional.linear(t85, t5, None) # t86: "cuda:0 f32[2, 5, 11008]"
t87 = torch.nn.functional.linear(t85, t7, None) # t87: "cuda:0 f32[2, 5, 11008]"
[t93] = nvFusion4(t86, t87)
t94 = torch.nn.functional.linear(t93, t16, None) # t94: "cuda:0 f32[2, 5, 4096]"
[t101, t105, t95] = nvFusion5(t104, t75, t94)
del t94
t106 = torch.nn.functional.linear(t105, t4, None) # t106: "cuda:0 f32[2, 5, 12288]"
t107 = torch.reshape(t106, (2, 5, 32, 3, 128)) # t107: "cuda:0 f32[2, 5, 32, 3, 128]"
del t106
t108 = torch.permute(t107, (0, 2, 3, 1, 4)) # t108: "cuda:0 f32[2, 32, 3, 5, 128]"
del t107
(t109, t110, t111) = torch.split(t108, (1, 1, 1), 2)
del t108
t112 = torch.reshape(t109, (2, 32, 5, 128)) # t112: "cuda:0 f32[2, 32, 5, 128]"
del t109
t113 = torch.reshape(t110, (2, 32, 5, 128)) # t113: "cuda:0 f32[2, 32, 5, 128]"
del t110
t114 = torch.reshape(t111, (2, 32, 5, 128)) # t114: "cuda:0 f32[2, 32, 5, 128]"
del t111
t135 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t135: "cuda:0 f32[2, 32, 5, 0]"
t137 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1]) # t137: "cuda:0 f32[2, 32, 5, 0]"
t115 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t115: "cuda:0 f32[2, 32, 5, 128]"
del t112
t125 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t125: "cuda:0 f32[2, 32, 5, 128]"
del t113
t116 = torch_slice_prim_impl(t115, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t116: "cuda:0 f32[2, 32, 5, 64]"
t117 = torch_slice_prim_impl(t115, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t117: "cuda:0 f32[2, 32, 5, 64]"
t127 = torch_slice_prim_impl(t125, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1]) # t127: "cuda:0 f32[2, 32, 5, 64]"
t126 = torch_slice_prim_impl(t125, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1]) # t126: "cuda:0 f32[2, 32, 5, 64]"
[t118, t128] = nvFusion6(t117, t127)
del t117, t127
t129 = torch.cat((t128, t126), -1) # t129: "cuda:0 f32[2, 32, 5, 128]"
del t128, t126
t119 = torch.cat((t118, t116), -1) # t119: "cuda:0 f32[2, 32, 5, 128]"
del t118, t116
[t124, t134] = nvFusion7(t115, t119, t125, t129, t49, t51)
del t115, t119, t125, t129
t136 = torch.cat((t124, t135), -1) # t136: "cuda:0 f32[2, 32, 5, 128]"
del t124, t135
t138 = torch.cat((t134, t137), -1) # t138: "cuda:0 f32[2, 32, 5, 128]"
del t134, t137
(t139, t140, t141, t142) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t136, t138, t114, None, 0.0, True, 0.08838834764831843)
t143 = torch.permute(t139, (0, 2, 1, 3)) # t143: "cuda:0 f32[2, 5, 32, 128]"
t144 = torch.reshape(t143, (2, 5, 4096)) # t144: "cuda:0 f32[2, 5, 4096]"
del t143
t145 = torch.nn.functional.linear(t144, t17, None) # t145: "cuda:0 f32[2, 5, 4096]"
[t146, t152, t156] = nvFusion8(t145, t155, t95)
del t145
t158 = torch.nn.functional.linear(t156, t8, None) # t158: "cuda:0 f32[2, 5, 11008]"
t157 = torch.nn.functional.linear(t156, t6, None) # t157: "cuda:0 f32[2, 5, 11008]"
[t164] = nvFusion9(t157, t158)
t165 = torch.nn.functional.linear(t164, t18, None) # t165: "cuda:0 f32[2, 5, 4096]"
[t166, t172, t176] = nvFusion10(t146, t165, t175)
del t165
t177 = torch.nn.functional.linear(t176, t9, None) # t177: "cuda:0 f32[2, 5, 32000]"
return {'output': t177, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19], 'flat_output': (t177,)}, ((t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95), (False, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2))
```
This is a straight-lined version of `GPT.forward` that has been optimized. Since it's running on CUDA, the [NvFuser](https://github.com/NVIDIA/Fuser) executor has created regions (look for "nvFusion") that fuse multiple operators together.
Operator fusion is very desirable with modern hardware and helps out in overhead-bound or device-bound settings by:
- Launching less kernels, thus reducing the kernel launch overhead.
- Reducing the number of memory accesses performed by reusing them in a fused operation
- Minimizing host-device communications
Thunder also uses a multi-level intermediate representation. If we let it print all levels
```python
forward_trace = thunder.last_traces(model)[-1]
print(forward_trace)
```
We can see as comments the primitives that compose the fusion regions. For instance, this is the region associated to [the `RMSNorm` implementation](https://github.com/Lightning-AI/litgpt/blob/9b6475dabf90c7acee506a026bd9fa86251835bf/litgpt/model.py#L409-L420)
```python
[t146, t152, t156] = nvFusion8(t145, t155, t95)
# t146 = prims.add(t145, t95) # t146: "cuda:0 f32[2, 5, 4096]"
# t147 = prims.mul(t146, t146) # t147: "cuda:0 f32[2, 5, 4096]"
# t148 = prims.sum(t147, (2,)) # t148: "cuda:0 f32[2, 5]"
# t149 = prims.broadcast_in_dim(t148, [2, 5, 1], [0, 1]) # t149: "cuda:0 f32[2, 5, 1]"
# t150 = prims.div(t149, 4096.0) # t150: "cuda:0 f32[2, 5, 1]"
# t151 = prims.add(t150, 1e-05) # t151: "cuda:0 f32[2, 5, 1]"
# t152 = prims.rsqrt(t151) # t152: "cuda:0 f32[2, 5, 1]"
# t153 = prims.broadcast_in_dim(t152, (2, 5, 4096), (0, 1, 2)) # t153: "cuda:0 f32[2, 5, 4096]"
# t154 = prims.mul(t146, t153) # t154: "cuda:0 f32[2, 5, 4096]"
# t156 = prims.mul(t154, t155) # t156: "cuda:0 f32[2, 5, 4096]"
```
Similarly, we can visualize the backward trace:
```python
backward_trace = thunder.last_backward_traces(model)[-1].python()
print(backward_trace)
```
```python
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, C1, = saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t178, = cotangents
clear_collection(cotangents)
del cotangents
t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \
t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \
t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \
t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0
clear_collection(C0)
del C0
b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \
= C1
clear_collection(C1)
del C1
t639 = torch.reshape(t178, (-1, 32000)) # t639: "cuda:0 f32[10, 32000]"
del t178
t643 = torch.permute(t639, (1, 0)) # t643: "cuda:0 f32[32000, 10]"
t644 = torch.reshape(t176, (-1, 4096)) # t644: "cuda:0 f32[10, 4096]"
del t176
t669 = torch.reshape(t164, (-1, 11008)) # t669: "cuda:0 f32[10, 11008]"
del t164
t686 = torch.reshape(t156, (-1, 4096)) # t686: "cuda:0 f32[10, 4096]"
del t156
t720 = torch.reshape(t144, (-1, 4096)) # t720: "cuda:0 f32[10, 4096]"
del t144
t776 = torch.reshape(t105, (-1, 4096)) # t776: "cuda:0 f32[10, 4096]"
del t105
t802 = torch.reshape(t93, (-1, 11008)) # t802: "cuda:0 f32[10, 11008]"
del t93
t819 = torch.reshape(t85, (-1, 4096)) # t819: "cuda:0 f32[10, 4096]"
del t85
t853 = torch.reshape(t73, (-1, 4096)) # t853: "cuda:0 f32[10, 4096]"
del t73
t911 = torch.reshape(t34, (-1, 4096)) # t911: "cuda:0 f32[10, 4096]"
del t34
t640 = torch.matmul(t639, t9) # t640: "cuda:0 f32[10, 4096]"
del t639, t9
t645 = torch.matmul(t643, t644) # t645: "cuda:0 f32[32000, 4096]"
del t643, t644
t641 = torch.reshape(t640, (2, 5, 4096)) # t641: "cuda:0 f32[2, 5, 4096]"
del t640
[t648, t663] = nvFusion0(f106, t166, t172, t175, t641)
del f106, t166, t172, t175, t641
t664 = torch.reshape(t663, (-1, 4096)) # t664: "cuda:0 f32[10, 4096]"
t668 = torch.permute(t664, (1, 0)) # t668: "cuda:0 f32[4096, 10]"
t665 = torch.matmul(t664, t18) # t665: "cuda:0 f32[10, 11008]"
del t664, t18
t670 = torch.matmul(t668, t669) # t670: "cuda:0 f32[4096, 11008]"
del t668, t669
t666 = torch.reshape(t665, (2, 5, 11008)) # t666: "cuda:0 f32[2, 5, 11008]"
del t665
[t672, t680] = nvFusion1(t157, t158, t666)
del t157, t158, t666
t681 = torch.reshape(t672, (-1, 11008)) # t681: "cuda:0 f32[10, 11008]"
del t672
t685 = torch.permute(t681, (1, 0)) # t685: "cuda:0 f32[11008, 10]"
t688 = torch.reshape(t680, (-1, 11008)) # t688: "cuda:0 f32[10, 11008]"
del t680
t692 = torch.permute(t688, (1, 0)) # t692: "cuda:0 f32[11008, 10]"
t689 = torch.matmul(t688, t6) # t689: "cuda:0 f32[10, 4096]"
del t688, t6
t682 = torch.matmul(t681, t8) # t682: "cuda:0 f32[10, 4096]"
del t681, t8
t694 = torch.matmul(t692, t686) # t694: "cuda:0 f32[11008, 4096]"
del t692
t687 = torch.matmul(t685, t686) # t687: "cuda:0 f32[11008, 4096]"
del t685, t686
t683 = torch.reshape(t682, (2, 5, 4096)) # t683: "cuda:0 f32[2, 5, 4096]"
del t682
t690 = torch.reshape(t689, (2, 5, 4096)) # t690: "cuda:0 f32[2, 5, 4096]"
del t689
[t698, t714] = nvFusion2(f101, t146, t152, t155, t663, t683, t690)
del f101, t146, t152, t155, t663, t683, t690
t715 = torch.reshape(t714, (-1, 4096)) # t715: "cuda:0 f32[10, 4096]"
t719 = torch.permute(t715, (1, 0)) # t719: "cuda:0 f32[4096, 10]"
t716 = torch.matmul(t715, t17) # t716: "cuda:0 f32[10, 4096]"
del t715, t17
t721 = torch.matmul(t719, t720) # t721: "cuda:0 f32[4096, 4096]"
del t719, t720
t717 = torch.reshape(t716, (2, 5, 4096)) # t717: "cuda:0 f32[2, 5, 4096]"
del t716
t722 = torch.reshape(t717, (2, 5, 32, 128)) # t722: "cuda:0 f32[2, 5, 32, 128]"
del t717
t723 = torch.permute(t722, (0, 2, 1, 3)) # t723: "cuda:0 f32[2, 32, 5, 128]"
del t722
(t724, t725, t726, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t723, t136, t138, t114, None, t139, t140, t141, t142, f90, b91, scale=f92)
del t723, t136, t138, t114, t139, t140, t141, t142, f90, b91, f92
t765 = torch.reshape(t726, (2, 32, 1, 5, 128)) # t765: "cuda:0 f32[2, 32, 1, 5, 128]"
del t726
t727 = torch_slice_prim_impl(t725, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t727: "cuda:0 f32[2, 32, 5, 128]"
del t725
t730 = torch_slice_prim_impl(t724, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t730: "cuda:0 f32[2, 32, 5, 128]"
del t724
[t747, t764] = nvFusion3(t49, t51, t727, t730)
del t727, t730
t766 = torch.reshape(t747, (2, 32, 1, 5, 128)) # t766: "cuda:0 f32[2, 32, 1, 5, 128]"
del t747
t767 = torch.reshape(t764, (2, 32, 1, 5, 128)) # t767: "cuda:0 f32[2, 32, 1, 5, 128]"
del t764
t768 = torch.cat((t767, t766, t765), i73) # t768: "cuda:0 f32[2, 32, 3, 5, 128]"
del t767, t766, t765, i73
t769 = torch.permute(t768, (0, 3, 1, 2, 4)) # t769: "cuda:0 f32[2, 5, 32, 3, 128]"
del t768
t770 = torch.reshape(t769, (2, 5, 12288)) # t770: "cuda:0 f32[2, 5, 12288]"
del t769
t771 = torch.reshape(t770, (-1, 12288)) # t771: "cuda:0 f32[10, 12288]"
del t770
t775 = torch.permute(t771, (1, 0)) # t775: "cuda:0 f32[12288, 10]"
t777 = torch.matmul(t775, t776) # t777: "cuda:0 f32[12288, 4096]"
del t775, t776
t772 = torch.matmul(t771, t4) # t772: "cuda:0 f32[10, 4096]"
del t771, t4
t773 = torch.reshape(t772, (2, 5, 4096)) # t773: "cuda:0 f32[2, 5, 4096]"
del t772
[t780, t796] = nvFusion4(f56, t101, t104, t714, t773, t95)
del f56, t101, t104, t714, t773, t95
t797 = torch.reshape(t796, (-1, 4096)) # t797: "cuda:0 f32[10, 4096]"
t801 = torch.permute(t797, (1, 0)) # t801: "cuda:0 f32[4096, 10]"
t798 = torch.matmul(t797, t16) # t798: "cuda:0 f32[10, 11008]"
del t797, t16
t803 = torch.matmul(t801, t802) # t803: "cuda:0 f32[4096, 11008]"
del t801, t802
t799 = torch.reshape(t798, (2, 5, 11008)) # t799: "cuda:0 f32[2, 5, 11008]"
del t798
[t805, t813] = nvFusion5(t799, t86, t87)
del t799, t86, t87
t814 = torch.reshape(t805, (-1, 11008)) # t814: "cuda:0 f32[10, 11008]"
del t805
t818 = torch.permute(t814, (1, 0)) # t818: "cuda:0 f32[11008, 10]"
t821 = torch.reshape(t813, (-1, 11008)) # t821: "cuda:0 f32[10, 11008]"
del t813
t825 = torch.permute(t821, (1, 0)) # t825: "cuda:0 f32[11008, 10]"
t822 = torch.matmul(t821, t5) # t822: "cuda:0 f32[10, 4096]"
del t821, t5
t815 = torch.matmul(t814, t7) # t815: "cuda:0 f32[10, 4096]"
del t814, t7
t827 = torch.matmul(t825, t819) # t827: "cuda:0 f32[11008, 4096]"
del t825
t820 = torch.matmul(t818, t819) # t820: "cuda:0 f32[11008, 4096]"
del t818, t819
t816 = torch.reshape(t815, (2, 5, 4096)) # t816: "cuda:0 f32[2, 5, 4096]"
del t815
t823 = torch.reshape(t822, (2, 5, 4096)) # t823: "cuda:0 f32[2, 5, 4096]"
del t822
[t831, t847] = nvFusion6(f51, t75, t796, t81, t816, t823, t84)
del f51, t75, t796, t81, t816, t823, t84
t848 = torch.reshape(t847, (-1, 4096)) # t848: "cuda:0 f32[10, 4096]"
t852 = torch.permute(t848, (1, 0)) # t852: "cuda:0 f32[4096, 10]"
t849 = torch.matmul(t848, t15) # t849: "cuda:0 f32[10, 4096]"
del t848, t15
t854 = torch.matmul(t852, t853) # t854: "cuda:0 f32[4096, 4096]"
del t852, t853
t850 = torch.reshape(t849, (2, 5, 4096)) # t850: "cuda:0 f32[2, 5, 4096]"
del t849
t855 = torch.reshape(t850, (2, 5, 32, 128)) # t855: "cuda:0 f32[2, 5, 32, 128]"
del t850
t856 = torch.permute(t855, (0, 2, 1, 3)) # t856: "cuda:0 f32[2, 32, 5, 128]"
del t855
(t857, t858, t859, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t856, t65, t67, t43, None, t68, t69, t70, t71, f40, b41, scale=f42)
del t856, t65, t67, t43, t68, t69, t70, t71, f40, b41, f42
t900 = torch.reshape(t859, (2, 32, 1, 5, 128)) # t900: "cuda:0 f32[2, 32, 1, 5, 128]"
del t859
t863 = torch_slice_prim_impl(t857, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t863: "cuda:0 f32[2, 32, 5, 128]"
del t857
t860 = torch_slice_prim_impl(t858, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1]) # t860: "cuda:0 f32[2, 32, 5, 128]"
del t858
[t882, t899] = nvFusion7(t49, t51, t860, t863)
del t49, t51, t860, t863
t902 = torch.reshape(t899, (2, 32, 1, 5, 128)) # t902: "cuda:0 f32[2, 32, 1, 5, 128]"
del t899
t901 = torch.reshape(t882, (2, 32, 1, 5, 128)) # t901: "cuda:0 f32[2, 32, 1, 5, 128]"
del t882
t903 = torch.cat((t902, t901, t900), i23) # t903: "cuda:0 f32[2, 32, 3, 5, 128]"
del t902, t901, t900, i23
t904 = torch.permute(t903, (0, 3, 1, 2, 4)) # t904: "cuda:0 f32[2, 5, 32, 3, 128]"
del t903
t905 = torch.reshape(t904, (2, 5, 12288)) # t905: "cuda:0 f32[2, 5, 12288]"
del t904
t906 = torch.reshape(t905, (-1, 12288)) # t906: "cuda:0 f32[10, 12288]"
del t905
t910 = torch.permute(t906, (1, 0)) # t910: "cuda:0 f32[12288, 10]"
t907 = torch.matmul(t906, t3) # t907: "cuda:0 f32[10, 4096]"
del t906, t3
t912 = torch.matmul(t910, t911) # t912: "cuda:0 f32[12288, 4096]"
del t910, t911
t908 = torch.reshape(t907, (2, 5, 4096)) # t908: "cuda:0 f32[2, 5, 4096]"
del t907
[t915, t931] = nvFusion8(f6, t24, t30, t33, t847, t908)
del f6, t24, t30, t33, t847, t908
t932 = torch.torch.ops.aten.embedding_backward(t931, t0, i0, -1, b1, b2) # t932: "cuda:0 f32[32000, 4096]"
del t931, t0, i0, b1, b2
return (None, None, None, t912, t777, t827, t694, t820, t687, t645, t648, t915, t780, t831, t698, t854, t803, t721, t670, t932)
```
These traces are long, and require some familiarity with the model implementation to follow them, but they allow you to:
- Inspect exactly what operations are run including their decompositions.
- Inspect the sizes of tensors, their device, data type and conversions.
- Apply transformations to the traces since the computations are completely decoupled from the data.
- Inspect the backward operations generated for each forward operation to understand what autograd is doing.
### Transforms
Transforms are one of the core features of Thunder. For example, they enable easy data parallel distribution. That is replicated data parallelism (DDP) and fully-sharded data parallelism (FSDP).
We provide ready-to-use Fabric strategies that integrate Thunder DDP|FSDP. Under the hood, the code is quite straightforward:
```python
model = thunder.distributed.ddp(model)
# or
# model = thunder.distributed.fsdp(model)
model = thunder.jit(model)
```
After applying the DDP transformation, the backward trace will include the expected all-reduce collectives:
```python
p1022 = torch_all_reduce_prim_impl(t1021, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True, False) # p1022: "FUTURE cuda:0 f32[16797696]"
...
t1059 = torch_wait_prim_impl(p1025) # t1059: "cuda:0 f32[131072000]"
```
With `L.Fabric`, this is how to use them:
```python
from extensions.extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrategy
# fully-sharded data parallel
strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3",
bucketing_strategy="BLOCK",
executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"),
state_dict_type="full",
)
# replicated data parallel
strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"))
fabric = L.Fabric(devices=devices, strategy=strategy)
fabric.launch()
model = fabric.setup(model) # JIT is called here
```
And in the case of FSDP all-gathers in forward and reduce-scatters in backward.
Meaning that Thunder automatically introduced the necessary collective operations to support data parallelism.
### Executors
Thunder allows you to define a priority list of executors that can map operators:
```python
import thunder
model = thunder.jit(
model,
executors=["sdpa", "torchcompile_cat", "nvfuser", "torch"]
)
```
Notice how `torch.compile` is a valid executor. This executor registers a few operators with improved performance so that you can utilize the fastest set of operator implementations possible.
### Custom executors
Lightning Thunder provides extension points to integrate fast kernels for operators in your model without having to modify your implementation.
For instance, the [Unsloth project](https://github.com/unslothai/unsloth/) provides several Triton kernels that can be used with LitGPT:
- Cross entropy loss
- SwiGLU (part of `LLaMAMLP`)
- RoPE
The [`unsloth` directory](unsloth) contains a [custom executor](unsloth/executor.py) that registers these operators for LitGPT.
We can enable this executor by passing it to the list of executors available. The order matters because we want to run its custom operators before
`NvFuser` creates its fusion regions.
```python
import thunder
model = thunder.jit(
model,
executors=["sdpa", "unsloth", "torchcompile_cat", "nvfuser", "torch"]
)
```
Doing this, the model trace now includes the Unsloth kernel calls:
```python
def augmented_forward_fn(*args):
...
(t121, _, _, _, _, _) = unsloth_apply_rope(t120, t21, t22)
...
(t189, t190) = unsloth_cross_entropy(t187, t188)
...
def backward_fn(saved_for_backward, cotangents):
...
t652 = unsloth_cross_entropy_backward(t651, t187, t188, t190) # t652: "cuda:0 f32[6, 320]"
...
t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4) # t763: "cuda:0 f32[2, 4, 3, 16]"
```
We provide a specific [pre-training script copy](pretrain.py) that uses this executor.
Given the Unsloth results below, these hand-written kernels do not seem to be worth it, showcasing the power of automated fusion compilers like [NvFuser](https://github.com/NVIDIA/Fuser).
## Examples and benchmarks
> [!WARNING]
> Lightning Thunder is alpha and not ready for production runs. Feel free to try it out, expect a few bumps along the way.
> We expect speed and memory usage to improve as we continue to develop it.
We provide a version of the main pre-training script [that integrates Thunder](pretrain.py) that uses TinyLlama, a 1.1B parameter LLM.
| Setting | Compiler | Executors | Devices | ms/iter @ step 10 | Memory (GB) |
|----------------------|----------|----------------------------------------|---------|-------------------|---------------|
| Fully-sharded ZeRO 3 | Eager | - | 8 | 456.57 | 22.13 |
| Fully-sharded ZeRO 3 | torch | - | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 333.56 | 21.40 |
| | | | | | |
| Replicated | Eager | - | 8 | 569.46 | 32.04 |
| Replicated | torch | - | 8 | Not supported | Not supported |
| Replicated | Thunder | sdpa, torchcompile | 8 | 426.44 | 22.19 |
| Replicated | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 356.01 | 27.42 |
| | | | | | |
| - | Eager | - | 1 | 447.65 | 29.84 |
| - | torch | - | 1 | Not supported | Not supported |
| - | Thunder | sdpa, torchcompile | 1 | 373.37 | 22.19 |
| - | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 322.25 | 27.42 |
| | | | | | |
| Unsloth | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 331.92 | 25.19 |
Reproduction details
Config:
```yaml
out_dir: out/pretrain-thunder
data: TinyStories
tokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0
logger_name: csv
```
Commands:
```bash
litgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true
python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1
```
`--compiler torch` (`torch.compile` without `thunder`) is not include because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601
The CUDA devices are all NVIDIA A100-SXM4-40GB.
```text
Python version: 3.10.12 [GCC 11.4.0] (64-bit runtime)
Is debug build: False
CUDA used to build PyTorch: 12.1
CUDA runtime version: 12.3.107
Nvidia driver version: 545.23.08
pytorch-triton==3.0.0+45fff310c8
torch==2.4.0.dev20240427+cu121
lightning==2.3.0.dev20240328
lightning-thunder==0.2.0.dev20240505
nvfuser_cu121==0.2.3.dev20240428
```
================================================
FILE: extensions/thunder/__init__.py
================================================
import sys
from pathlib import Path
# support running without installing as a package, adding extensions to the Python path
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
================================================
FILE: extensions/thunder/pretrain.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import math
import os
import pprint
import sys
import time
from dataclasses import asdict
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops
from torch.utils.data import DataLoader
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal
from litgpt import Tokenizer
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import DataModule, TinyLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention
from litgpt.parser_config import save_hyperparameters
from litgpt.types import LoggerChoice
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
find_resume_path,
instantiate_torch_optimizer,
num_parameters,
parse_devices,
reset_parameters,
save_config,
)
# support running without installing as a package
wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))
def forward_and_loss(model: nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
logits = model(input_ids)
# disable chunk_size to enable the unsloth cross entropy kernel
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
return loss
def setup(
model_name: Optional[str] = None,
model_config: Optional[Config] = None,
out_dir: Path = Path("out/pretrain"),
initial_checkpoint_dir: Optional[Path] = None,
resume: Union[bool, Literal["auto"], Path] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=512,
micro_batch_size=4,
max_tokens=int(3e12), # 3 trillion
max_norm=1.0,
min_lr=4e-5,
lr_warmup_steps=2000,
tie_embeddings=False,
),
eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
devices: Union[int, str] = "auto",
num_nodes: int = 1,
tokenizer_dir: Optional[Path] = None,
logger_name: LoggerChoice = "tensorboard",
seed: int = 42,
compiler: Optional[Literal["thunder", "torch"]] = "thunder",
executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
strategy: Literal["auto", "ddp", "fsdp"] = "fsdp",
):
"""Pretrain a model.
Arguments:
model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
``model_config``.
model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
``model_config``.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
Useful for continued pretraining. Mutually exclusive with ``resume``.
resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
devices: How many devices/GPUs to use. Uses all GPUs by default.
num_nodes: How many nodes the code is being run on.
tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
module require this.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
compiler: If desired, the compiler/JIT to use.
executors: If using Thunder, the executors to enable.
strategy: If desired, the strategy to use.
"""
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
elif model_config is None and model_name is None:
model_name = "tiny-llama-1.1b"
config = Config.from_name(model_name) if model_config is None else model_config
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
# in case the dataset requires the Tokenizer
tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None
logger = choose_logger(
logger_name,
out_dir,
name=f"pretrain-{config.name}",
resume=bool(resume),
log_interval=train.log_interval,
log_args=asdict(log),
)
if devices * num_nodes > 1:
if compiler == "thunder":
if strategy == "fsdp":
from extensions.thunder.strategies import ThunderFSDPStrategy
strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3",
bucketing_strategy="BLOCK",
state_dict_type="full",
jit=False,
)
elif strategy == "ddp":
from extensions.thunder.strategies import ThunderDDPStrategy
strategy = ThunderDDPStrategy(jit=False)
else:
if strategy == "fsdp":
strategy = FSDPStrategy(
auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="FULL_SHARD"
)
else:
strategy = "auto"
fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision="bf16-true", loggers=[logger])
fabric.launch()
if compiler is not None:
global forward_and_loss
forward_and_loss = (
jit(forward_and_loss, executors) if compiler == "thunder" else torch.compile(forward_and_loss)
)
fabric.print(pprint.pformat(hparams))
if logger_name in ("tensorboard", "wandb", "mlflow"):
fabric.logger.log_hyperparams(hparams)
main(
fabric=fabric,
devices=devices,
num_nodes=num_nodes,
seed=seed,
initial_checkpoint_dir=initial_checkpoint_dir,
resume=resume,
config=config,
data=data,
out_dir=out_dir,
tokenizer_dir=tokenizer_dir,
tokenizer=tokenizer,
train=train,
eval=eval,
optimizer=optimizer,
compiler=compiler,
)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
initial_checkpoint_dir: Optional[Path],
resume: Union[bool, Literal["auto"], Path],
config: Config,
data: DataModule,
out_dir: Path,
tokenizer_dir: Optional[Path],
tokenizer: Optional[Tokenizer],
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
compiler: Optional[Literal["thunder", "torch"]],
num_nodes: int = 1,
) -> None:
validate_args(train, eval, initial_checkpoint_dir, resume)
if fabric.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)
if train.tie_embeddings:
model.transformer.wte.weight = model.lm_head.weight
if train.max_seq_length:
model.max_seq_length = train.max_seq_length
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters: {num_parameters(model):,}")
model = fabric.setup(model)
if compiler == "thunder":
# avoid `Tensor.register_hook` which is unsupported
model._register_backward_hook = lambda *_: None
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
if initial_checkpoint_dir:
fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model)
state = {
"model": model,
"optimizer": optimizer,
"train_dataloader": train_dataloader,
"iter_num": 0,
"step_count": 0,
}
resume = find_resume_path(resume, out_dir)
if resume:
fabric.print(f"Resuming training from {resume}")
fabric.load(resume, state)
train_time = time.perf_counter()
fit(
fabric=fabric,
devices=devices,
num_nodes=num_nodes,
state=state,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
out_dir=out_dir,
tokenizer_dir=tokenizer_dir,
train=train,
eval=eval,
optimizer=optimizer,
)
fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")
# Save final checkpoint
save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
def fit(
fabric: L.Fabric,
devices: int,
state: dict,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
out_dir: Path,
tokenizer_dir: Optional[Path],
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
model = state["model"]
optimizer = state["optimizer"]
validate(fabric, model, val_dataloader, max_iters=2) # sanity check
throughput = ThroughputMonitor(fabric, window_size=5)
with torch.device("meta"):
meta_model = GPT(model.config)
x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))
model_fwd = lambda: meta_model(x) # noqa: F821
model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0) # noqa: F821
measured_flops = measure_flops(meta_model, model_fwd, model_loss)
fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x
max_tokens_per_device = train.max_tokens // fabric.world_size
tokens_per_iter = train.micro_batch_size * model.max_seq_length
max_iters = max_tokens_per_device // tokens_per_iter
log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
initial_iter = state["iter_num"]
train_iterator = CycleIterator(train_dataloader)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
fabric.barrier()
total_t0 = time.perf_counter()
val_loss = "n/a"
warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)
for train_data in train_iterator:
if state["iter_num"] >= max_iters:
break
# determine and set the learning rate for this iteration
lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
state["iter_num"] += 1
iter_t0 = time.perf_counter()
input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()
is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
loss = forward_and_loss(model, input_ids, targets)
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
# THUNDER unsupported: https://github.com/Lightning-AI/lightning-thunder/issues/2357
# fabric.clip_gradients(model, optimizer, max_norm=train.max_norm)
optimizer.step()
optimizer.zero_grad()
state["step_count"] += 1
if state["iter_num"] % log_iter_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=(t1 - total_t0),
flops=(measured_flops * log_iter_interval),
batches=state["iter_num"],
samples=(state["iter_num"] * train.micro_batch_size),
lengths=(state["iter_num"] * train.micro_batch_size * model.max_seq_length),
)
metrics = {
"loss": loss,
"iter": state["iter_num"],
"step": state["step_count"],
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"remaining_time": (
(t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
),
"tokens": state["iter_num"] * train.micro_batch_size * model.max_seq_length,
"total_tokens": (state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size),
"learning_rate": lr,
}
if isinstance(val_loss, float):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
f" remaining time: {timedelta(seconds=int(metrics['remaining_time']))!s}"
)
throughput_metrics = throughput.compute()
metrics.update(throughput_metrics)
fabric.log_dict(metrics, step=state["iter_num"] - 1)
if val_dataloader is not None and not is_accumulating and state["step_count"] % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
val_loss = val_loss.item()
td = time.perf_counter() - t0
fabric.print(f"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms")
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=state["iter_num"] - 1)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")
@torch.no_grad()
def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:
fabric.barrier()
fabric.print("Validating ...")
model.eval()
losses = []
for k, batch in enumerate(val_dataloader):
if k >= max_iters:
break
input_ids = batch[:, 0 : model.max_seq_length].contiguous().long()
targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long()
loss = forward_and_loss(model, input_ids, targets)
losses.append(loss)
val_loss = torch.stack(losses).mean()
model.train()
fabric.barrier()
return val_loss
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
return train_dataloader, val_dataloader
# learning rate decay scheduler (cosine with linear warmup)
def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > max_iters, return min learning rate
if it > max_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
def initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None:
"""GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745)."""
# Adapted from https://github.com/jzhang38/TinyLlama
def init_weights(module, std):
nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
nn.init.zeros_(module.bias)
for mod in model.modules():
if isinstance(mod, (nn.Embedding, nn.Linear)):
mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd))
# need a separate loop because `mod.proj` below is a `nn.Linear` too
for mod in model.modules():
if isinstance(mod, (LLaMAMLP, CausalSelfAttention, MultiheadLatentAttention)):
mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer))
if not isinstance(fabric.strategy, FSDPStrategy):
reset_parameters(model)
def init_out_dir(out_dir: Path) -> Path:
if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
return out_dir
def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
model = state["model"]
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
fabric.save(checkpoint_file, state)
if fabric.global_rank == 0:
save_hyperparameters(setup, checkpoint_file.parent)
if tokenizer_dir is not None:
copy_config_files(tokenizer_dir, checkpoint_file.parent)
save_config(model.config, checkpoint_file.parent)
def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
issues = []
unsupported = [(train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["max_tokens", "max_norm"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if initial_checkpoint_dir and resume:
issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.")
if issues:
raise ValueError("\n".join(issues))
def jit(fn: Callable, executors: List[str]) -> Any:
assert executors is not None
from unsloth.executor import unsloth_ex # import for registration # noqa: F401
import thunder
return thunder.jit(fn, executors=executors)
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
CLI(setup)
================================================
FILE: extensions/thunder/strategies/__init__.py
================================================
from .thunder_ddp import ThunderDDPStrategy # noqa: F401
from .thunder_fsdp import ThunderFSDPStrategy # noqa: F401
================================================
FILE: extensions/thunder/strategies/thunder_ddp.py
================================================
"""Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually."""
from contextlib import nullcontext
from datetime import timedelta
from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from typing_extensions import override
from litgpt.constants import _THUNDER_AVAILABLE
if TYPE_CHECKING:
from thunder import Executor
class ThunderDDPStrategy(ParallelStrategy):
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
jit: bool = True,
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
**kwargs: Any,
):
r"""Strategy for Replicated Data Parallel provided by Lightning Thunder.
.. warning:: This is an :ref:`experimental ` feature.
Arguments:
jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
jitting a function that includes the model.
executors: The list of Thunder executors to enable. They can be either string aliases for the executors
or the actual executor instances.
\**kwargs: See available parameters in :func:`thunder.distributed.ddp`.
"""
if not _THUNDER_AVAILABLE:
raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
self.parallel_devices = parallel_devices
self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = executors
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()
self._ddp_kwargs = kwargs
@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def num_nodes(self) -> int:
return self._num_nodes
@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
@override
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank}
@override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
@property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend
@override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
@override
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
@override
def setup_module(self, module: Module) -> Module:
import thunder
if (cd := thunder.compile_data(module)) is not None:
# the module was already jitted
if thunder.compile_stats(module).last_traces is not None:
raise RuntimeError(
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" DDP transform. Remove the `forward` call before `fabric.setup()`"
)
assert cd.is_module # sanity check
ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
# update the compile data state
cd.fn = ddp_module
cd.process_group_for_ddp = ddp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.ddp(module, **self._ddp_kwargs)
if not self.jit:
return module
return thunder.jit(module, executors=self.executors)
@override
def module_to_device(self, module: Module) -> None:
module.to(self.root_device)
@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=[self.root_device.index])
else:
torch.distributed.barrier()
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj
obj = [obj]
torch.distributed.broadcast_object_list(obj, src)
return obj[0]
def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
def _set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
class _ThunderDataParalellBackwardSyncControl(_BackwardSyncControl):
def __init__(self):
self._enabled = False
@override
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""
In Thunder, we cannot use ``module.no_sync()`` because reduction happens at the end of the context manager.
It assumes that the user will reuse it across all gradient accumulation iterations:
.. code-block:: python
with model.no_sync():
for _ in range(len(gradient_accumulation_iters)):
fwd()
bwd() # uses no-sync-backward trace
fwd()
bwd() # uses regular-backward trace
However, Fabric is designed to the context manager every iteration:
.. code-block:: python
for i in range(iters):
is_accumulating = (i + 1) % gradient_accumulation_iters != 0
ctx = model.no_sync() if is_accumulating else nullcontext()
with ctx:
fwd()
bwd()
So we need to be smart about when to sync grads based on the ``enabled`` value.
More info in https://github.com/Lightning-AI/lit-thunder-LEGACY/issues/2085
"""
if not getattr(module, "use_ddp", False) and not getattr(module, "use_fsdp", False):
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
f" `{self.__class__.__name__}.no_backward_sync` is applied DDP or FSDP."
f" Got: {module.__class__.__name__}."
)
from thunder.distributed import skip_data_parallel_grad_sync
previous, self._enabled = self._enabled, enabled
if enabled:
return skip_data_parallel_grad_sync()
if not enabled and previous:
return _SyncGradsContextManager(module)
return nullcontext()
class _SyncGradsContextManager:
def __init__(self, module: Module) -> None:
self._module = module
@override
def __enter__(self) -> None:
from thunder.distributed import _sync_grads
_sync_grads(self._module)
@override
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
pass
================================================
FILE: extensions/thunder/strategies/thunder_fsdp.py
================================================
"""Fabric Strategy to support Thunder FSDP: To be upstreamed into Fabric eventually."""
import shutil
from contextlib import ExitStack, nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Tuple, Union
import torch
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.strategy import TBroadcast, _apply_filter, _Sharded, _validate_keys_for_strict_loading
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.load import _METADATA_FILENAME, _move_state_into
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, _Stateful
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override
from extensions.thunder.strategies.thunder_ddp import _ThunderDataParalellBackwardSyncControl
from litgpt.constants import _THUNDER_AVAILABLE
if TYPE_CHECKING:
from thunder import Executor
from thunder.distributed import FSDPBucketingStrategy, FSDPType
from thunder.distributed.checkpoint import StateDictOptions
_FSDP_TYPE = Union[FSDPType, Literal["ZERO2", "ZERO3"]]
_BUCKETING_STRATEGY = Union[FSDPBucketingStrategy, Literal["NONE", "LAYER", "BLOCK"]]
class ThunderFSDPStrategy(ParallelStrategy, _Sharded):
def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
jit: bool = True,
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
sharding_strategy: "_FSDP_TYPE" = "ZERO3",
bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
state_dict_type: Literal["full", "sharded"] = "sharded",
**kwargs: Any,
):
r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.
.. warning:: This is an :ref:`experimental ` feature.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically.
Arguments:
jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
jitting a function that includes the model.
executors: The list of Thunder executors to enable. They can be either string aliases for the executors
or the actual executor instances.
sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
of them:
- ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
- ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.
Also accepts a :class:`thunder.distributed.FSDPType` enum value.
bucketing_strategy: Enables combining the collective operations for sets of layers.
- ``"NONE"``: No bucketing (default).
- ``"LAYER"``: Create buckets per layer class.
- ``"BLOCK"``: Create buckets per layer block.
Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
(default).
- ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
a folder with as many files as the world size.
\**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.
"""
if not _TORCH_GREATER_EQUAL_2_2:
raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.")
if not _THUNDER_AVAILABLE:
raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
self.parallel_devices = parallel_devices
self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
from thunder.distributed import FSDPBucketingStrategy, FSDPType
self.sharding_strategy = (
FSDPType[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy
)
self.bucketing_strategy = (
FSDPBucketingStrategy[bucketing_strategy.upper()]
if isinstance(bucketing_strategy, str)
else bucketing_strategy
)
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = executors
self._state_dict_type = state_dict_type
self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()
self._fsdp_kwargs = kwargs
@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def num_nodes(self) -> int:
return 1
@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
@override
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank}
@override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
@override
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
@override
def setup_module(self, module: Module) -> Module:
import thunder
if (cd := thunder.compile_data(module)) is not None:
# the module was already jitted
if thunder.compile_stats(module).last_traces is not None:
raise RuntimeError(
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" FSDP transform. Remove the `forward` call before `fabric.setup()`"
)
assert cd.is_module # sanity check
fsdp_module = thunder.distributed.fsdp(
cd.fn,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
# update the compile data state
cd.fn = fsdp_module
cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.fsdp(
module,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
if not self.jit:
return module
return thunder.jit(module, executors=self.executors)
@override
def module_to_device(self, module: Module) -> None:
pass
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
stack = ExitStack()
if empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP
stack.enter_context(torch.device("meta"))
stack.enter_context(precision_init_ctx)
stack.enter_context(module_sharded_ctx)
return stack
@override
def module_sharded_context(self) -> ContextManager:
return nullcontext()
@override
def all_reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
) -> Tensor:
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=[self.root_device.index])
else:
torch.distributed.barrier()
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj
obj = [obj]
torch.distributed.broadcast_object_list(obj, src)
return obj[0]
@override
def clip_gradients_norm(
self,
module: Module,
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> Tensor:
raise NotImplementedError
@override
def save_checkpoint(
self,
path: _PATH,
state: Dict[str, Union[Module, Optimizer, Any]],
storage_options: Optional[Any] = None,
filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
) -> None:
if storage_options is not None:
raise TypeError(
"`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
" `FSDPStrategy` does not use the `CheckpointIO`."
)
if filter is not None:
raise NotImplementedError("Filtering checkpoint paths is not implemented")
# broadcast the path from rank 0 to ensure all the states are saved in a common path
path = Path(self.broadcast(path))
if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, save
modules = [module for module in state.values() if has_fsdp_modules(module)]
if len(modules) == 0:
raise ValueError(
"Could not find a FSDP model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
)
if len(modules) > 1:
raise ValueError(
"Found multiple FSDP models in the given state. Saving checkpoints with FSDP is"
" currently limited to a single model per checkpoint. To save multiple models, call the"
" save method for each model separately with a different path."
)
if self._state_dict_type == "sharded":
if _is_full_checkpoint(path):
path.unlink()
path.mkdir(parents=True, exist_ok=True)
options = StateDictOptions(full_state_dict=False, cpu_offload=True, rank0_only=False)
converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank)
save(converted_state, path)
if self.global_rank == 0:
torch.save(metadata, path / _METADATA_FILENAME)
elif self._state_dict_type == "full":
if _is_sharded_checkpoint(path):
shutil.rmtree(path)
options = StateDictOptions(full_state_dict=True, cpu_offload=True, rank0_only=True)
converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank)
converted_state.update(metadata)
if self.global_rank == 0:
torch.save(converted_state, path)
else:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
@override
def load_checkpoint(
self,
path: _PATH,
state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
if not state:
raise ValueError(
f"Got `FSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least"
" a model instance to reload is required. Pass it in like so:"
" `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`"
)
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(path))
from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, load, load_model_state_dict
if isinstance(state, Module):
if not _is_full_checkpoint(path):
raise ValueError(
"Failed to load checkpoint directly into the model. The given path must be a single file"
f" containing the full state dict: {path}"
)
state_dict = torch.load(str(path), mmap=True, map_location="cpu")
options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False)
load_model_state_dict(state_dict, _unwrap_tom(state), options, self.local_rank)
return {}
if isinstance(state, Optimizer):
raise NotImplementedError(
"Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy."
)
modules = {key: module for key, module in state.items() if has_fsdp_modules(module)}
if len(modules) == 0:
raise ValueError(
"Could not find a FSDP model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
)
if len(modules) > 1:
raise ValueError(
"Found multiple FSDP models in the given state. Loading checkpoints with FSDP is"
" currently limited to a single model per checkpoint. To load multiple models, call the"
" load method for each model separately with a different path."
)
optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
module_key, module = list(modules.items())[0]
module = _unwrap_tom(module)
if _is_sharded_checkpoint(path):
options = StateDictOptions(full_state_dict=False, cpu_offload=True, strict=strict, rank0_only=False)
# Load the DCP state dict, which requires a holder state dict
converted_state, _ = _get_state_dict(state, None, options, self.local_rank)
load(converted_state, path)
load_model_state_dict(converted_state[module_key], module, options, self.local_rank)
# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
for key in requested_metadata_keys:
if key not in metadata:
continue
state[key] = metadata.pop(key)
# return the remaining metadata that wasn't requested as part of `state`
return metadata
if _is_full_checkpoint(path):
options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False)
if not options.rank0_only or self.local_rank == 0:
map_location = "cpu" if options.cpu_offload else None
checkpoint = torch.load(str(path), mmap=True, map_location=map_location)
load_model_state_dict(checkpoint[module_key], module, options, self.local_rank)
else:
checkpoint = {}
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
# Load metadata (anything not a module or optimizer)
_move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys)
# return the remaining metadata that wasn't requested as part of `state`
return checkpoint
raise ValueError(
f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
" directory with FSDP checkpoint shards, or a single file with a full checkpoint."
)
def _setup_distributed(self) -> None:
reset_seed()
self._set_world_ranks()
process_group_backend = _get_default_process_group_backend_for_device(self.root_device)
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, process_group_backend)
def _set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
def _is_sharded_checkpoint(path: Path) -> bool:
"""A heuristic check to determine whether the path points to a directory with checkpoint shards."""
return path.is_dir() and (path / _METADATA_FILENAME).is_file()
def _is_full_checkpoint(path: Path) -> bool:
return path.is_file()
def _get_state_dict(
state: Dict[str, Any],
filter: Optional[Dict[str, Callable[[str, Any], bool]]],
options: "StateDictOptions",
rank: int,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
from thunder.distributed.checkpoint import get_model_state_dict
# replace the modules and optimizer objects in the state with their local state dict
# and separate the user's metadata
converted_state: Dict[str, Any] = {}
metadata: Dict[str, Any] = {}
for key, obj in state.items():
converted: Any
if isinstance(obj, Module):
converted = get_model_state_dict(_unwrap_tom(obj), options, rank)
target_dict = converted_state
elif isinstance(obj, Optimizer):
# TODO: optimizer support
converted = obj.state_dict()
target_dict = converted_state
else: # everything not a module or optimizer is considered metadata
converted = obj.state_dict() if isinstance(obj, _Stateful) else obj
target_dict = metadata
_apply_filter(key, filter or {}, converted, target_dict)
return converted_state, metadata
def _unwrap_tom(obj: object) -> object:
# TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder
from thunder import ThunderModule
if isinstance(obj, ThunderModule):
return obj._model
return obj
================================================
FILE: extensions/thunder/unsloth/__init__.py
================================================
================================================
FILE: extensions/thunder/unsloth/executor.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
from pathlib import Path
from typing import Optional, Tuple
import torch
from torch import Tensor
import litgpt.model
from litgpt.constants import _THUNDER_AVAILABLE
from litgpt.model import LLaMAMLP as OriginalLLaMAMLP
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import get_grad, mean_backward, put_grads
from thunder.extend import OperatorExecutor, register_executor
from thunder.torch import ne, sum, true_divide
if _THUNDER_AVAILABLE:
import thunder
import thunder.torch as ltorch
sys.path.append(str(Path(__file__).parent))
import kernels
unsloth_ex = OperatorExecutor("unsloth", version="0.1")
register_executor(unsloth_ex)
"""
====================
Cross Entropy Loss
====================
"""
def unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tuple[TensorProxy, TensorProxy]:
return (
TensorProxy(
shape=(logits.shape[0],),
# the cross entropy kernel only supports float32
dtype=thunder.dtypes.float32,
device=logits.device,
requires_grad=logits.requires_grad,
),
TensorProxy(shape=(logits.shape[0],), dtype=thunder.dtypes.float32, device=logits.device, requires_grad=False),
)
unsloth_cross_entropy = unsloth_ex.register_operator(
"unsloth_cross_entropy", meta=unsloth_cross_entropy_meta, fn=kernels.cross_entropy_loss._cross_entropy_forward_impl
)
def unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor:
# clone() because the kernel writes the grads in the logits
return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels)
def unsloth_cross_entropy_backward_meta(
dlosses: TensorProxy, logits: TensorProxy, logsumexp: TensorProxy, labels: TensorProxy
) -> TensorProxy:
return thunder.TensorProxy(like=logits)
unsloth_cross_entropy_backward = unsloth_ex.register_operator(
"unsloth_cross_entropy_backward", meta=unsloth_cross_entropy_backward_meta, fn=unsloth_cross_entropy_backward_impl
)
def unsloth_cross_entropy_checker(
logits: TensorProxy,
labels: TensorProxy,
weight: Optional[TensorProxy] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0,
) -> bool:
return (
weight is None
and size_average is None
and reduce is None
and reduction in ("none", "mean")
and ignore_index == -100
and label_smoothing == 0.0
and logits.device.type == "cuda"
and labels.device.type == "cuda"
)
def cross_entropy_to_unsloth(
logits: TensorProxy,
labels: TensorProxy,
weight: Optional[TensorProxy] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0,
) -> Tuple[TensorProxy, TensorProxy]:
loss, logsumexp = unsloth_cross_entropy(logits, labels)
if reduction == "mean":
# "mean" reduction is not part of the kernel
# TODO: this doesn't consider that all elements could be masked, causing a division by 0
n_items = sum(ne(labels, -100))
loss = true_divide(sum(loss), n_items)
elif reduction != "none":
raise NotImplementedError(reduction)
return loss, logsumexp
def unsloth_cross_entropy_grad(
logits: TensorProxy,
labels: TensorProxy,
weight: Optional[TensorProxy] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0,
) -> TensorProxy:
loss, logsumexp = cross_entropy_to_unsloth(**locals())
grad = get_grad(loss)
if reduction == "mean":
grad = mean_backward(logsumexp.ndim, logsumexp.shape, (0,), grad)
logits_grad = unsloth_cross_entropy_backward(grad, logits, labels, logsumexp)
put_grads((logits,), (logits_grad,))
return loss
# registers as cross entropy implementation, including the execution transform and now a grad transform
unsloth_ex.register_implementation(
ltorch.cross_entropy,
checker=unsloth_cross_entropy_checker,
execution_transform=lambda *args: cross_entropy_to_unsloth(*args)[0],
grad_transform=unsloth_cross_entropy_grad,
)
"""
=========
RMSNorm
=========
The RMSNorm kernel is not integrated because it's not numerically equal and it doesn't compute the gradient for the
weight, just for the input.
"""
"""
========
SwiGLU
========
"""
def swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(e) * g
class ThunderLLaMAMLP(OriginalLLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = swiglu(x_fc_1, x_fc_2)
return self.proj(x)
litgpt.model.LLaMAMLP = ThunderLLaMAMLP
def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy:
return TensorProxy(like=e)
litgpt_swiglu = unsloth_ex.register_operator("litgpt_swiglu", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu)
unsloth_swiglu_forward = unsloth_ex.register_operator(
"unsloth_swiglu_forward", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args)
)
def unsloth_swiglu_backward_meta(DW: TensorProxy, e: TensorProxy, g: TensorProxy) -> Tuple[TensorProxy, TensorProxy]:
return TensorProxy(like=g), TensorProxy(like=e)
def unsloth_swiglu_backward_fn(DW: Tensor, e: Tensor, g: Tensor) -> Tuple[Tensor, Tuple]:
B, T, n_embd = e.shape
e = e.view(-1, n_embd)
g = g.view(-1, n_embd)
DW, e, g = kernels.swiglu_DWf_DW_dfg_kernel(DW, e, g)
e = e.view(B, T, n_embd)
g = g.view(B, T, n_embd)
return g, e
unsloth_swiglu_backward = unsloth_ex.register_operator(
"unsloth_swiglu_backward", meta=unsloth_swiglu_backward_meta, fn=unsloth_swiglu_backward_fn
)
def swiglu_to_unsloth_checker(e: TensorProxy, g: TensorProxy) -> bool:
return e.device.type == "cuda" and g.device.type == "cuda"
def unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy:
h = unsloth_swiglu_forward(**locals())
grad = get_grad(h)
e_grad, g_grad = unsloth_swiglu_backward(grad, e, g)
put_grads((e, g), (e_grad, g_grad))
return h
unsloth_ex.register_implementation(
litgpt_swiglu,
checker=swiglu_to_unsloth_checker,
execution_transform=unsloth_swiglu_forward,
grad_transform=unsloth_swiglu_grad,
)
"""
======
RoPE
======
"""
def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
return TensorProxy(like=x)
apply_rope = unsloth_ex.register_operator(
"litgpt_apply_rope", like=apply_rope_meta, fn=litgpt.model.apply_rope, replaces=litgpt.model.apply_rope
)
def unsloth_apply_rope_meta(
Q: TensorProxy, cos: TensorProxy, sin: TensorProxy
) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]:
batch, n_heads, seq_len, head_dim = Q.shape
assert seq_len <= cos.shape[-2]
BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2)
div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
return TensorProxy(like=Q), cos, sin, n_groups, BLOCK_SIZE, num_warps
unsloth_apply_rope = unsloth_ex.register_operator(
"unsloth_apply_rope", meta=unsloth_apply_rope_meta, fn=kernels._rope_embedding_forward_impl
)
def unsloth_apply_rope_backward_meta(
dY: TensorProxy, cos: TensorProxy, sin: TensorProxy, n_groups: int, BLOCK_SIZE: int, num_warps: int
) -> TensorProxy:
return TensorProxy(like=dY)
unsloth_apply_rope_backward = unsloth_ex.register_operator(
"unsloth_apply_rope_backward", meta=unsloth_apply_rope_backward_meta, fn=kernels._rope_embedding_backward_impl
)
def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:
return len(x.shape) == 4 and x.device.type == "cuda" and cos.device.type == "cuda" and sin.device.type == "cuda"
def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
Q, cos, sin, n_groups, BLOCK_SIZE, num_warps = unsloth_apply_rope(x, cos, sin)
dY = get_grad(Q)
dX = unsloth_apply_rope_backward(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps)
put_grads((x,), (dX,))
return Q
unsloth_ex.register_implementation(
apply_rope,
checker=apply_rope_to_unsloth_checker,
execution_transform=lambda *args: unsloth_apply_rope(*args)[0],
grad_transform=unsloth_apply_rope_grad,
)
================================================
FILE: extensions/thunder/unsloth/kernels/__init__.py
================================================
from .cross_entropy_loss import _cross_entropy_backward_impl, _cross_entropy_forward_impl # noqa: F401
from .rope_embedding import ROPE_GROUP_SIZE, _rope_embedding_backward_impl, _rope_embedding_forward_impl # noqa: F401
from .swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel # noqa: F401
from .utils import calculate_settings # noqa: F401
================================================
FILE: extensions/thunder/unsloth/kernels/cross_entropy_loss.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from litgpt.constants import _TRITON_AVAILABLE
from .utils import MAX_FUSED_SIZE, calculate_settings
if _TRITON_AVAILABLE:
import triton
import triton.language as tl
@triton.jit
def _cross_entropy_forward(
logits_ptr,
logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
Pi = exp(xi) / sum(exp(xi))
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
= -y [ x - log[sum(exp(x))] ]
= y * (log[sum(exp(x))] - x)
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
logsumexp is also stable
Take y = log[sum(exp(x))]
exp(y) = sum(exp(x))
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
exp(y) = exp(c)*sum(exp(x - c))
y = log(exp(c)*sum(exp(x - c)))
y = c + log[sum(exp(x - c))]
This means we can set c = max(x) to make sure
exp(x - c) always is exp(x - max(x)).
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx
labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = logsumexp - x
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr,
logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE: tl.constexpr,
N_CHUNKS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
256K vocab divided in 4 chunks
|-65536-| |-65536-| |-65536-| |-65536-|
|-------| |-------| |-------| |-------|
|-------| |-------| |-------| |-------|
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
Notice we can do logsumexp for each chunk and then
logsumexp[chunk_sum(logsumexp)] == logsumexp
chunk_sum = log[chunk_sum(logsumexp)]
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
= log[sum(exp(a)) + ... + sum(exp(z))]
= logsumexp(x)
This means we can perform a logsumexp for each chunk, then do a
final logsumexp reduction!
Ie do: logsumexp(chunked_logsumexp) - x
"""
row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if chunk_idx == 0:
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = -1.0 * x
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
tl.store(logsumexp_ptr, logsumexp)
pass
@triton.jit
def _cross_entropy_backward(
logits_ptr,
logits_row_stride,
dloss_ptr,
dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
block_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
dloss_ptr += row_idx * dloss_row_stride
col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)
pass
def _cross_entropy_forward_impl(logits, labels):
n_rows, vocab_size = logits.shape
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype=torch.float32, device="cuda")
if n_chunks == 1:
# For small vocabs <= 65336 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype=torch.float32, device="cuda")
_cross_entropy_forward[(n_rows,)](
logits,
logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
logsumexp = torch.empty(
(
n_rows,
n_chunks,
),
dtype=torch.float32,
device="cuda",
)
_chunked_cross_entropy_forward[
(
n_rows,
n_chunks,
)
](
logits,
logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
N_CHUNKS=n_chunks,
BLOCK_SIZE=MAX_FUSED_SIZE,
num_warps=32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
logsumexp = torch.logsumexp(logsumexp, dim=1) # Row sum
losses += logsumexp
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
return losses, logsumexp
def _cross_entropy_backward_impl(dlosses, logits, logsumexp, labels):
n_rows, vocab_size = logits.shape
BLOCK_SIZE = 4096
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks = div + (mod != 0)
_cross_entropy_backward[
(
n_rows,
n_blocks,
)
](
logits,
logits.stride(0),
dlosses,
dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=8,
)
return logits
================================================
FILE: extensions/thunder/unsloth/kernels/rope_embedding.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from litgpt.constants import _TRITON_AVAILABLE
from .utils import calculate_settings
if _TRITON_AVAILABLE:
import triton
import triton.language as tl
ROPE_GROUP_SIZE = 4
@triton.heuristics(
{
"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],
}
)
@triton.jit
def _rope_embedding(
Q,
Q_row_stride,
cos,
cos_row_stride,
sin,
sin_row_stride,
seqlen,
head_dim: tl.constexpr,
n_heads: tl.constexpr,
BACKWARD_PASS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
ROPE_GROUP_SIZE: tl.constexpr = 4,
):
"""
Calculates the RoPE Embedding quickly
RoPE is Q * cos + rotate_half(Q) * sin
See our blog post for more info
"""
row_position = tl.program_id(0)
group_head_position = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
sin1 = tl.load(sin + (row_position % seqlen) * sin_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0)
cos1 = tl.load(cos + (row_position % seqlen) * cos_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
pass
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
head_start = group_head_position * ROPE_GROUP_SIZE
head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
for k in range(head_start, head_end):
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)
tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)
pass
pass
def _rope_embedding_forward_impl(Q, cos, sin):
Q = Q.transpose(1, 2).clone()
cos, sin = cos.squeeze(), sin.squeeze()
batch, seq_len, n_heads, head_dim = Q.shape
Q = Q.reshape(batch * seq_len, n_heads * head_dim)
n_rows, n_cols = Q.shape
assert seq_len <= cos.shape[0]
# [TODO] Changing blocksize to head_dim//2 seems to have
# some concurrency / un-deterministic issues.
BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2) # (head_dim//2)
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
_rope_embedding[
(
n_rows,
n_groups,
)
](
Q,
Q.stride(0),
cos,
cos.stride(0),
sin,
sin.stride(0),
seq_len,
head_dim,
n_heads,
BACKWARD_PASS=False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
Q = Q.view(batch, seq_len, n_heads, head_dim)
Q = Q.transpose(1, 2)
return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps
def _rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps):
dY = dY.transpose(1, 2)
batch, seq_len, n_heads, head_dim = dY.shape
dY = dY.reshape(batch * seq_len, n_heads * head_dim)
# Must be reshape not view
n_rows, n_cols = dY.shape
_rope_embedding[
(
n_rows,
n_groups,
)
](
dY,
dY.stride(0),
cos,
cos.stride(0),
sin,
sin.stride(0),
seq_len,
head_dim,
n_heads,
BACKWARD_PASS=True,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
dY = dY.transpose(1, 2)
return dY
================================================
FILE: extensions/thunder/unsloth/kernels/swiglu.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from litgpt.constants import _TRITON_AVAILABLE
if _TRITON_AVAILABLE:
import triton
import triton.language as tl
@triton.jit
def _fg_kernel(
e,
g,
h,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask=mask)
pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_fg_kernel[grid](
e,
g,
h,
n_elements,
BLOCK_SIZE=1024,
)
return h
pass
@triton.jit
def _DWf_DW_dfg_kernel(
DW,
e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
f = (se * e).to(dtype)
h = f * g
df = DW * f
dg = DW * g
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
tl.store(g + offsets, de_row, mask=mask) # de
pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
_DWf_DW_dfg_kernel[grid](
DW,
e,
g,
n_elements,
BLOCK_SIZE=1024,
)
return DW, e, g
pass
================================================
FILE: extensions/thunder/unsloth/kernels/utils.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from litgpt.constants import _TRITON_AVAILABLE
if _TRITON_AVAILABLE:
import triton
MAX_FUSED_SIZE = 65536 # 2**16
next_power_of_2 = triton.next_power_of_2
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
)
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return BLOCK_SIZE, num_warps
pass
================================================
FILE: extensions/xla/README.md
================================================
# TPU support
This project utilizes [`Fabric`](https://lightning.ai/docs/fabric/stable), which supports TPUs via [PyTorch XLA](https://github.com/pytorch/xla).
> [!NOTE]
> This guide assumes that you have already set-up your [Google Cloud environment](https://cloud.google.com/run/docs/setup).
To set up a Google Cloud instance with a TPU v4 VM, run the following commands:
```shell
gcloud compute tpus tpu-vm create litgpt --version=tpu-vm-v4-base --accelerator-type=v4-8 --zone=us-central2-b
gcloud compute tpus tpu-vm ssh litgpt --zone=us-central2-b
```
You can also choose a different TPU type. To do so, change the `version`, `accelerator-type`, and `zone` arguments. Find all regions and zones [here](https://cloud.google.com/tpu/docs/regions-zones).
Multihost caveats
TPU v4-8 uses a single host. SSH'ing into the machine and running commands manually will only work when using a single host (1 slice in the TPU pod).
In multi-host environments, such as larger TPU pod slices, it's necessary to launch all commands on all hosts simultaneously to avoid hangs.
For local development, it is advisable to upload a zip file containing all your current changes and execute it inside the VM from your personal computer:
```shell
# Zip the local directory, excluding large directories from the zip. You may want to keep them.
zip -r local_changes.zip . -x ".git/*" "checkpoints/*" "data/*" "out/*"
# Copy the .zip file to the TPU VM
gcloud compute tpus tpu-vm scp --worker=all local_changes.zip "litgpt:~"
# Unzip on each host
gcloud compute tpus tpu-vm ssh litgpt --worker=all --command="cd ~; unzip -q -o local_changes.zip"
# Example of a typical workflow
gcloud compute tpus tpu-vm ssh tmp --worker=all --command="cd ~; bash install_dependencies.sh"
gcloud compute tpus tpu-vm ssh tmp --worker=all --command="cd ~; bash prepare_checkpoints.sh"
gcloud compute tpus tpu-vm ssh tmp --worker=all --command="cd ~; bash run_desired_script.sh"
# This will allow you to kill all python processes on all workers
gcloud compute tpus tpu-vm ssh tmp --worker=all --command="pkill -e python"
```
Notice how the commands to install the environment and prepare checkpoints need to be run on all workers, since the filesystem
for each worker (host) is not shared.
For the rest of this tutorial, it will be assumed that it is being run on a single host for simplicity.
Once inside the machine, clone the repository and install the dependencies:
```shell
git clone https://github.com/Lightning-AI/litgpt
cd litgpt
pip install .
```
Install Optimized BLAS:
```shell
sudo apt update
sudo apt install libopenblas-dev
```
Since LitGPT requires a torch version newer than torch 2.0.0, manually install nightly builds of torch and torch_xla:
```shell
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
```
While computations will run by default using the new PjRT runtime, it is recommended to set the following environment variables:
```shell
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
export PJRT_DEVICE=TPU
```
> [!NOTE]
> An extensive guide on setup and available options can be found [here](https://cloud.google.com/tpu/docs/v4-users-guide).
Since a new machine was created, you may need to download pretrained weights.
They can be copied to the machine using `gcloud compute tpus tpu-vm scp`, or you can follow the steps described in our [downloading guide](../../tutorials/download_model_weights.md).
It is also recommended to set up a persistent disk from which to load checkpoints.
Follow [this guide](https://cloud.google.com/tpu/docs/setup-persistent-disk#setting_up_a_tpu_vm_and_a_persistent_disk) to do so.
Read-write disks are not supported in multihost VM setups, so persistent disks cannot be used to save checkpoints in that case.
Persistent disks can still be useful in read-only mode to load pretrained weights before finetuning or inference.
In multihost settings, FSDP will save checkpoint shards per host and consolidate them into a single checkpoint.
For safekeeping, it is recommended to upload the consolidated checkpoints to a Google Cloud bucket.
Alternatively, you can use the `scp` command to transfer these checkpoints from the TPU VM periodically, although this is not implemented in our scripts.
## Inference
This project provides custom versions of the regular recipes to run with XLA in the `xla` directory.
To generate text, use the following command:
```shell
python3 xla/generate/base.py --prompt "Hello, my name is" --num_samples 3
```
For the first generation, this command will take around 17 seconds as XLA needs to compile the graph.
Subsequent generations will take around 2 seconds.
## Fine-tuning
To get started fine-tuning Falcon 7B with adapter, run the following command:
```shell
python3 xla/scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b
python3 xla/finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true
```
Multihost caveats
This script is configured to save "full" checkpoints, which isn't possible on multihost TPU VMs.
Here's how you can consolidate them together into a single one after training with `state_dict_type="sharded"`:
```shell
path_to_shards="out/adapter/alpaca/lit_model_adapter_finetuned"
mkdir -p $path_to_shards
workers=4 # 4 hosts
for ((i = 0; i < workers; i++)); do
# aggregate all shards locally
gcloud compute tpus tpu-vm scp --worker=$i "litgpt:${path_to_shards}/*" "${path_to_shards}/" --zone us-central2-b
done
# copy all shards to all workers
gcloud compute tpus tpu-vm scp --worker=all ${path_to_shards}/* "litgpt:${path_to_shards}/" --zone us-central2-b
# consolidate the shards in each worker
gcloud compute tpus tpu-vm ssh tmp --worker=all --command="python -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix ${path_to_shards}/checkpoint --ckpt_suffix '_rank-*-of-*.pth' --save_path ${path_to_shards}.pth" --zone us-central2-b
```
Since the TPU VM host RAM is limited (200 GB), we implement a technique to sequentially load and shard the checkpoint that can be enabled by
setting `reduce_cpu_memory_usage_during_load = True`. This is necessary to load falcon-40b.
To generate text with the adapter fine-tuned model weights, use the following command:
```shell
python3 xla/generate/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth
```
> **Warning**
> Remember to delete your instance when you are done.
>
> ```shell
> gcloud compute tpus tpu-vm delete litgpt --zone=us-central2-b
> ```
## Computational Performance
Using the [adapter finetuning script](finetune/adapter.py) and XLA's FSDP implementation, a 49.57% MFU was achieved with Falcon 7B on a v4-32 (micro batch size 7), and a 39.67% MFU was achieved with Falcon 40B on a v4-512 (micro batch size 3) at a fixed 1034 maximum sequence length.
Since the TPU VM host has limited system memory (RAM) compared to device memory (HBM), specific techniques were implemented to limit peak RAM usage when loading the model and pretrained weights before sharding, as well as when saving sharded checkpoints.
A v4 chip has 32 GiB HBM, so with 4 devices per host (4 * 32 = 128 GiB HBM), each host has 188 GiB RAM, which is shared across the devices.
Therefore, any RAM allocation over 188/4 = 47 GiB would exceed the host's RAM capacity.
A ~24B parameter model on CPU (with half precision) would be the largest possible model under this setup without the techniques used in our scripts.
================================================
FILE: extensions/xla/__init__
================================================
import sys
from pathlib import Path
# support running without installing as a package, adding extensions to the Python path
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
================================================
FILE: extensions/xla/finetune/__init__
================================================
================================================
FILE: extensions/xla/finetune/adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple
import lightning as L
import torch
import torch_xla.core.xla_model as xm
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import XLAFSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor, measure_flops
from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from litgpt.tokenizer import Tokenizer
from litgpt.utils import check_valid_checkpoint_dir, chunked_cross_entropy, estimate_flops, lazy_load, num_parameters
# support running without installing as a package
wd = Path(__file__).parents[3].resolve()
sys.path.append(str(wd))
from xla.generate.base import generate # noqa: E402
from xla.scripts.prepare_alpaca import generate_prompt # noqa: E402
from xla.utils import rank_print, sequential_load_and_fsdp_wrap # noqa: E402
eval_interval = 200
save_interval = 200
eval_iters = 100
eval_max_new_tokens = 100
log_interval = 1
devices = XLAAccelerator.auto_device_count()
# the state of very large models will not fit on the system RAM, this flag can alleviate it by loading it on each rank
# sequentially
reduce_cpu_memory_usage_during_load = False
# Hyperparameters
learning_rate = 3e-3
batch_size = 4
micro_batch_size = batch_size
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
epoch_size = 50000 # train dataset size
num_epochs = 5
max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
weight_decay = 0.02
warmup_steps = 2 * (epoch_size // micro_batch_size) // devices // gradient_accumulation_iters # 2 epochs
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
def setup(
*,
data_dir: Path = Path("data/alpaca"),
checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"),
out_dir: Path = Path("out/adapter/alpaca"),
precision: str = "bf16-true",
) -> None:
if devices > 1:
strategy = XLAFSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
state_dict_type="full", # change to "sharded" in multi-host environments where the filesystem is not shared
sequential_save=True,
)
else:
strategy = "auto"
logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
rank_print(fabric, hparams)
fabric.launch(main, data_dir, checkpoint_dir, out_dir)
def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
check_valid_checkpoint_dir(checkpoint_dir)
fabric.seed_everything(1337) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
train_data = torch.load(data_dir / "train.pt")
val_data = torch.load(data_dir / "test.pt")
config = Config.from_name(name=checkpoint_dir.name, adapter_start_layer=0)
checkpoint_path = checkpoint_dir / "lit_model.pth"
rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
if reduce_cpu_memory_usage_during_load:
model = sequential_load_and_fsdp_wrap(fabric, lambda: GPT(config), checkpoint_path)
else:
with fabric.init_module(empty_init=False):
model = GPT(config)
checkpoint = lazy_load(checkpoint_path)
# strict=False because missing keys due to adapter weights not contained in state dict
model.load_state_dict(checkpoint, strict=False)
model = fabric.setup_module(model)
# mark as trainable only after sharding due to https://github.com/pytorch/xla/pull/5484
mark_only_adapter_as_trainable(model)
# these are not correct in the sharding case
rank_print(fabric, f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
rank_print(fabric, f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)
optimizer = fabric.setup_optimizers(optimizer)
fabric.seed_everything(1337 + fabric.global_rank)
train_time = time.perf_counter()
train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir)
rank_print(fabric, f"Training time: {(time.perf_counter() - train_time):.2f}s")
# Save the final checkpoint at the end of training
save_path = out_dir / "lit_model_adapter_finetuned.pth"
save_adapter_checkpoint(fabric, model, save_path)
def train(
fabric: L.Fabric,
model: GPT,
optimizer: torch.optim.Optimizer,
train_data: List[Dict],
val_data: List[Dict],
checkpoint_dir: Path,
out_dir: Path,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length = get_longest_seq_length(train_data)
model.max_seq_length = longest_seq_length
# to avoid recompilation, this script is configured to pad batches to the `longest_seq_length`
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
with torch.device("meta"):
meta_model = GPT(model.config)
mark_only_adapter_as_trainable(meta_model)
# "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
# When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
# consider passing `flops_per_batch=estimated_flops` instead
estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
rank_print(fabric, f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
# this assumes that all samples have a fixed length equal to the longest sequence length
# which is most likely false during finetuning
x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
forward_fn = lambda: meta_model(x) # noqa: F821
loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0) # noqa: F821
measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
rank_print(fabric, f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x
throughput = ThroughputMonitor(fabric, window_size=50)
step_count = 0
total_t0 = time.perf_counter()
xm.mark_step()
for iter_num in range(1, max_iters + 1):
if step_count <= warmup_steps:
# linear warmup
lr = learning_rate * step_count / warmup_steps
for param_group in optimizer.param_groups:
param_group["lr"] = lr
iter_t0 = time.perf_counter()
input_ids, targets = get_batch(fabric, train_data, longest_seq_length)
is_accumulating = iter_num % gradient_accumulation_iters != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
xm.mark_step()
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / gradient_accumulation_iters)
xm.mark_step()
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
step_count += 1
else:
xm.mark_step()
if iter_num % log_interval == 0:
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0,
batches=iter_num,
samples=iter_num * micro_batch_size,
lengths=iter_num * micro_batch_size * longest_seq_length,
flops=measured_flops * log_interval,
)
throughput.compute_and_log(step=iter_num)
rank_print(
fabric,
f"iter {iter_num} step {step_count}:"
# uncomment to print the loss. this will considerably slow down the iteration times
# + f" loss {loss.item():.4f},"
+ f" iter time: {(t1 - iter_t0) * 1000:.2f}ms"
+ (" (optimizer.step)" if not is_accumulating else ""),
)
if not is_accumulating and step_count % eval_interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length)
t1 = time.perf_counter() - t0
rank_print(fabric, f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
fabric.barrier()
if not is_accumulating and step_count % save_interval == 0:
checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
save_adapter_checkpoint(fabric, model, checkpoint_path)
# xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, longest_seq_length: int
) -> torch.Tensor:
rank_print(fabric, "Validating ...")
model.eval()
losses = torch.zeros(eval_iters)
xm.mark_step()
for k in range(eval_iters):
input_ids, targets = get_batch(fabric, val_data, longest_seq_length)
logits = model(input_ids)
xm.mark_step()
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
# produce an example:
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
rank_print(fabric, instruction)
sample = {"instruction": instruction, "input": ""}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, device=fabric.device)
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8)
model.clear_kv_cache()
output = tokenizer.decode(output)
rank_print(fabric, output)
model.train()
return val_loss
def get_batch(fabric: L.Fabric, data: List[Dict], longest_seq_length: int) -> Tuple[torch.Tensor, torch.Tensor]:
ix = torch.randint(len(data), (micro_batch_size,))
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
labels = [data[i]["labels"].type(torch.int64) for i in ix]
def pad_right(x, pad_id):
# pad right using a fixed longest sequence length to avoid recompilation
n = longest_seq_length - len(x)
return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
x, y = fabric.to_device((x, y))
return x, y
def get_longest_seq_length(data: List[Dict]) -> int:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
return max(len(d["input_ids"]) for d in data)
def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
rank_print(fabric, f"Saving adapter weights to {str(file_path)!r}")
fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
if __name__ == "__main__":
from jsonargparse import CLI
CLI(setup)
================================================
FILE: extensions/xla/generate/__init__
================================================
================================================
FILE: extensions/xla/generate/adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
from pathlib import Path
from typing import Optional
import lightning as L
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies import XLAFSDPStrategy
from litgpt import Tokenizer
from litgpt.adapter import GPT, Block, Config
from litgpt.prompts import Alpaca
from litgpt.utils import check_valid_checkpoint_dir, lazy_load
# support running without installing as a package
wd = Path(__file__).parents[3].resolve()
sys.path.append(str(wd))
from xla.generate.base import generate # noqa: E402
from xla.utils import rank_print # noqa: E402
def setup(
prompt: str = "What food do llamas eat?",
*,
input: str = "",
sys_prompt: Optional[str] = None,
adapter_path: Path = Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"),
checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"),
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
temperature: float = 0.8,
precision: str = "bf16-true",
) -> None:
"""Generates a response based on a given instruction and an optional input.
This script will only work with checkpoints from the instruction-tuned Adapter model.
See `xla/finetune/adapter.py`.
Args:
prompt: The prompt/instruction (Alpaca style).
input: Optional input (Alpaca style).
sys_prompt: Optional system prompt.
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
`xla/finetune/adapter.py`.
checkpoint_dir: The path to the checkpoint folder with pretrained model weights.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
"""
devices = XLAAccelerator.auto_device_count()
strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else "auto"
fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
fabric.launch(main, prompt, input, sys_prompt, adapter_path, checkpoint_dir, max_new_tokens, top_k, temperature)
def main(
fabric: L.Fabric,
prompt: str,
input: str,
sys_prompt: Optional[str],
adapter_path: Path,
checkpoint_dir: Path,
max_new_tokens: int,
top_k: Optional[int],
temperature: float,
) -> None:
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml", adapter_start_layer=0)
checkpoint_path = checkpoint_dir / "lit_model.pth"
rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
rank_print(fabric, f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
adapter_checkpoint = lazy_load(adapter_path)
checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
model.load_state_dict(checkpoint)
rank_print(fabric, f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(checkpoint_dir)
# TODO: Load prompt style from checkpoint and apply it here
prompt_style = Alpaca()
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
t0 = time.perf_counter()
y = generate(
model,
encoded,
max_returned_tokens,
max_seq_length=max_returned_tokens,
temperature=temperature,
top_k=top_k,
eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0
output = tokenizer.decode(y)
output = output.split("### Response:")[1] if "### Response:" in output else output
output = output.strip()
fabric.print(output)
tokens_generated = y.size(0) - prompt_length
rank_print(
fabric, f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
if __name__ == "__main__":
from jsonargparse import CLI
CLI(setup)
================================================
FILE: extensions/xla/generate/base.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
from pathlib import Path
from typing import Optional
import lightning as L
import torch
import torch_xla.core.xla_model as xm
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies import XLAFSDPStrategy
from litgpt import GPT, Config, Tokenizer
from litgpt.model import Block
from litgpt.utils import check_valid_checkpoint_dir, lazy_load
# support running without installing as a package
wd = Path(__file__).parents[3].resolve()
sys.path.append(str(wd))
from xla.utils import rank_print # noqa: E402
# xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor
@torch.no_grad()
def generate(
model: GPT,
idx: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
eos_id: If specified, stop generating any more token once the token is triggered.
"""
T = idx.size(0)
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
# TODO: FSDP has an internal broadcasting issue, so we are forced to have this be of length 1 until it's fixed
input_pos = torch.tensor([0], device=device)
xm.mark_step()
# generate up to a fixed number of tokens
for _ in range(max_returned_tokens):
x = idx.index_select(0, input_pos).view(1, -1)
# forward
logits = model(x, input_pos)
logits = logits[0, -1] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
# advance
input_pos = input_pos[-1:] + 1
xm.mark_step()
# concatenate the new generation
idx = idx.index_copy(0, input_pos, idx_next)
# if token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[:input_pos] # include the EOS token
return idx
def setup(
prompt: str = "What food do llamas eat?",
*,
num_samples: int = 1,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
temperature: float = 0.8,
checkpoint_dir: Path = Path("checkpoints/tiiuae/falcon-7b"),
precision: str = "bf16-true",
) -> None:
"""Generates text samples based on a pre-trained model and tokenizer.
Args:
prompt: The prompt string to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_dir: The checkpoint directory to load.
precision: Indicates the Fabric precision setting to use.
"""
devices = XLAAccelerator.auto_device_count()
strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else "auto"
fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
fabric.launch(main, prompt, num_samples, max_new_tokens, top_k, temperature, checkpoint_dir)
def main(
fabric: L.Fabric,
prompt: str,
num_samples: int,
max_new_tokens: int,
top_k: Optional[int],
temperature: float,
checkpoint_dir: Path,
) -> None:
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
rank_print(fabric, f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
rank_print(fabric, f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
model.load_state_dict(checkpoint.get("model", checkpoint))
rank_print(fabric, f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(checkpoint_dir)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
L.seed_everything(1234)
for i in range(num_samples):
with fabric.init_tensor():
# enable the kv cache
model.set_kv_cache(batch_size=1)
t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
t = time.perf_counter() - t0
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
rank_print(
fabric,
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
file=sys.stderr,
)
if __name__ == "__main__":
from jsonargparse import CLI
CLI(setup)
================================================
FILE: extensions/xla/scripts/__init__
================================================
================================================
FILE: extensions/xla/scripts/prepare_alpaca.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
import json
from pathlib import Path
from typing import Optional
import torch
import yaml
from lightning_utilities.core.imports import RequirementCache
from torch.utils.data import random_split
from tqdm import tqdm
from litgpt.tokenizer import Tokenizer
from litgpt.utils import CLI
def prepare(
destination_path: Path = Path("data/alpaca"),
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
val_split_fraction: float = 0.03865, # to get exactly 2000 validation samples,
seed: int = 42,
mask_inputs: bool = False, # as in alpaca-lora
data_file_name: str = "alpaca_data_cleaned_archive.json",
data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json",
ignore_index: int = -100,
max_seq_length: Optional[int] = None,
) -> None:
"""Prepare the Alpaca dataset for instruction tuning.
The output is a training and test dataset saved as `train.pt` and `test.pt`,
which stores the preprocessed and tokenized prompts and labels.
"""
if max_seq_length is None:
with open(checkpoint_dir / "model_config.yaml", encoding="utf-8") as file:
config = yaml.safe_load(file)
max_seq_length = config["block_size"]
destination_path.mkdir(parents=True, exist_ok=True)
data_file_path = destination_path / data_file_name
print("Loading data file...")
download_if_missing(data_file_path, data_file_url)
with open(data_file_path, encoding="utf-8") as file:
data = json.load(file)
print("Loading tokenizer...")
tokenizer = Tokenizer(checkpoint_dir)
# Partition the dataset into train and test
train_set, test_set = random_split(
data, [1.0 - val_split_fraction, val_split_fraction], generator=torch.Generator().manual_seed(seed)
)
train_set, test_set = list(train_set), list(test_set)
print(f"train has {len(train_set):,} samples")
print(f"test has {len(test_set):,} samples")
print("Processing train split ...")
train_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(train_set)
]
torch.save(train_set, destination_path / "train.pt")
print("Processing test split ...")
test_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(test_set)
]
torch.save(test_set, destination_path / "test.pt")
def download_if_missing(file_path: Path, file_url: str) -> None:
"""Downloads the raw json data file and saves it in the given destination."""
if file_path.exists() and file_path.stat().st_size > 0:
return
requests_available = RequirementCache("requests")
if not requests_available:
raise ModuleNotFoundError(str(requests_available))
import requests
with open(file_path, "w", encoding="utf-8") as f:
f.write(requests.get(file_url).text)
def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict:
"""Processes a single sample.
Each sample in the dataset consists of:
- instruction: A string describing the task
- input: A string holding a special input value for the instruction.
This only applies to some samples, and in others this is empty.
- output: The response string
This function processes this data to produce a prompt text and a label for
supervised training. The prompt text is formed as a single message including both
the instruction and the input. The label/target is the same message but with the
response attached.
Finally, both the prompt and the label get tokenized. If desired, all tokens
in the label that correspond to the original input prompt get masked out (default).
"""
full_prompt = generate_prompt(example)
full_prompt_and_response = full_prompt + example["output"]
encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)
encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)
# The labels are the full prompt with response, but with the prompt masked out
labels = encoded_full_prompt_and_response.clone()
if mask_inputs:
labels[: len(encoded_full_prompt)] = ignore_index
return {**example, "input_ids": encoded_full_prompt_and_response, "labels": labels}
def generate_prompt(example: dict) -> str:
"""Generates a standardized message to prompt the model with an instruction, optional input and a
'response' field."""
if example["input"]:
return (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
)
return (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
f"### Instruction:\n{example['instruction']}\n\n### Response:"
)
if __name__ == "__main__":
CLI(prepare)
================================================
FILE: extensions/xla/utils.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import itertools
from functools import partial
from pathlib import Path
from typing import Any, Callable
import lightning as L
import torch
from lightning.fabric.strategies.xla_fsdp import XLAFSDPStrategy, _activation_checkpointing_auto_wrapper
from lightning_utilities.core.rank_zero import rank_prefixed_message
from litgpt import GPT
def rank_print(fabric: L.Fabric, message: object, *, flush: bool = True, **kwargs: Any) -> None:
if fabric.local_rank == 0:
message = str(message)
# let each host print, but only on rank 0
message = rank_prefixed_message(message, fabric.global_rank)
# TPU VM will only print when the script finishes if `flush=False`
print(message, flush=flush, **kwargs)
def materialize_parameters(module: torch.nn.Module, device: torch.device) -> None:
for module_name, module in module.named_modules():
if any(
param.is_meta for param in itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))
):
module.to_empty(device=device, recurse=False)
module.reset_parameters()
def sequential_load_and_fsdp_wrap(
fabric: L.Fabric, get_model: Callable[[], GPT], checkpoint_path: Path
) -> torch.nn.Module:
assert fabric._launched
# similar logic could be implemented for regular FSDP, but this implementation is specific to XLAFSDP
assert isinstance(fabric.strategy, XLAFSDPStrategy)
with fabric.init_module(empty_init=False), torch.device("meta"):
model = get_model()
# TODO: this could be made faster by broadcasting in separate process groups for each host
if fabric.local_rank == 0:
# load the full checkpoint on a single rank to limit the system memory usage
state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=False) # mmap=True hangs
else:
# XLA cannot broadcast different number of tensors or different shapes in each rank. To get around this
# limitation, we need to load the checkpoint on meta device to get the correct number of tensors and materialize
# them as necessary
state_dict = torch.load(checkpoint_path, map_location="meta", mmap=False)
fsdp_kwargs = fabric.strategy._parse_fsdp_kwargs()
if "auto_wrapper_callable" in fsdp_kwargs:
# includes activation checkpointing if configured
wrap = fsdp_kwargs.pop("auto_wrapper_callable")
else:
wrap = partial(_activation_checkpointing_auto_wrapper, set())
fsdp_kwargs.pop("auto_wrap_policy", None) # this needs to be removed or else root wrapping would error
for i, block in enumerate(model.transformer.h):
rank_print(fabric, f"Broadcasting transformer block {i}")
# get the relevant piece of the state dict
to_load = {}
for param_name, _ in block.named_parameters():
if (key := f"transformer.h.{i}.{param_name}") not in state_dict:
continue
param = state_dict.pop(key)
if not param.is_meta:
to_load[param_name] = param
else:
# materialize this parameter for broadcast to work
to_load[param_name] = torch.empty_like(param, device="cpu")
to_load = fabric.broadcast(to_load)
rank_print(fabric, f"Loading transformer block {i}")
keys = block.load_state_dict(to_load, strict=False, assign=True)
assert not keys.unexpected_keys
# materialize any leftover meta parameters, regular FSDP does it automatically
materialize_parameters(block, torch.device("cpu")) # init on CPU, FSDP will shard and move it
# XLA FSDP only supports fp32 parameters. If the checkpoint had a different dtype, this needs to be converted
# since we are loading with assign=True
block = block.to(torch.float32)
# shard the block
rank_print(fabric, f"Wrapping transformer block {i}")
wrapped_block = wrap(block, **fsdp_kwargs)
model.transformer.h[i] = wrapped_block
# load the rest of the state_dict, this assumes that all keys need to be loaded
# an alternative technique would be to do load the rest of the state dict at once, but we want to materialize
# and move the params to the xla device to reduce the system memory usage
for key in list(state_dict):
rank_print(fabric, f"Loading {key}")
param = state_dict.pop(key)
if param.is_meta:
# materialize this parameter for broadcast to work
param = torch.empty_like(param, device="cpu")
param = fabric.broadcast(param)
param = param.to(device=fabric.device, dtype=torch.float32)
keys = model.load_state_dict({key: param}, strict=False, assign=True)
assert not keys.unexpected_keys
assert not state_dict
# materialize any leftover meta parameters, regular FSDP does it automatically
rank_print(fabric, "Materializing leftover parameters")
materialize_parameters(model, fabric.device)
return model
================================================
FILE: litgpt/__init__.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import logging
import re
from litgpt.api import LLM
from litgpt.config import Config
from litgpt.model import GPT # needs to be imported before config
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
logging.getLogger("torch._dynamo.variables.torch").addFilter(lambda record: not pattern.search(record.getMessage()))
# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
__all__ = ["LLM", "GPT", "Config", "PromptStyle", "Tokenizer"]
================================================
FILE: litgpt/__main__.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import warnings
import torch
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
from litgpt.chat.base import main as chat_fn
from litgpt.deploy.serve import run_server as serve_fn
from litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn
from litgpt.finetune.adapter import setup as finetune_adapter_fn
from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn
from litgpt.finetune.full import setup as finetune_full_fn
from litgpt.finetune.lora import setup as finetune_lora_fn
from litgpt.generate.adapter import main as generate_adapter_fn
from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn
from litgpt.generate.base import main as generate_base_fn
from litgpt.generate.full import main as generate_full_fn
from litgpt.generate.sequentially import main as generate_sequentially_fn
from litgpt.generate.speculative_decoding import main as generate_speculatively_fn
from litgpt.generate.tp import main as generate_tp_fn
from litgpt.parser_config import parser_commands
from litgpt.pretrain import setup as pretrain_fn
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint as convert_hf_checkpoint_fn
from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint as convert_lit_checkpoint_fn
from litgpt.scripts.convert_pretrained_checkpoint import (
convert_pretrained_checkpoint as convert_pretrained_checkpoint_fn,
)
from litgpt.scripts.download import download_from_hub as download_fn
from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn
PARSER_DATA = {
"download": download_fn,
"chat": chat_fn,
"finetune": finetune_lora_fn,
"finetune_lora": finetune_lora_fn,
"finetune_full": finetune_full_fn,
"finetune_adapter": finetune_adapter_fn,
"finetune_adapter_v2": finetune_adapter_v2_fn,
"pretrain": pretrain_fn,
"generate": generate_base_fn,
"generate_full": generate_full_fn,
"generate_adapter": generate_adapter_fn,
"generate_adapter_v2": generate_adapter_v2_fn,
"generate_sequentially": generate_sequentially_fn,
"generate_speculatively": generate_speculatively_fn,
"generate_tp": generate_tp_fn,
"convert_to_litgpt": convert_hf_checkpoint_fn,
"convert_from_litgpt": convert_lit_checkpoint_fn,
"convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn,
"merge_lora": merge_lora_fn,
"evaluate": evaluate_fn,
"serve": serve_fn,
}
def _check_commands():
assert set(parser_commands()) == set(PARSER_DATA.keys()), (
"PARSER_DATA has to be kept in sync with litgpt.parser_config.parser_commands()"
)
def main() -> None:
_check_commands()
set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
# PyTorch bug that raises a false-positive warning
# More info: https://github.com/Lightning-AI/litgpt/issues/1561
warning_message = r"The epoch parameter in `scheduler.step\(\)` was not necessary and is being deprecated.*"
warnings.filterwarnings(
action="ignore", message=warning_message, category=UserWarning, module=r".*torch\.optim\.lr_scheduler.*"
)
torch.set_float32_matmul_precision("high")
CLI(PARSER_DATA)
if __name__ == "__main__":
main()
================================================
FILE: litgpt/adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation of the paper:
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
Port for LitGPT
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from typing_extensions import Self
from litgpt.config import Config as BaseConfig
from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@dataclass
class Config(BaseConfig):
adapter_prompt_length: int = 10
adapter_start_layer: int = 2
class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.max_seq_length = self.config.block_size
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, CausalSelfAttention):
module.reset_parameters()
class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.model.CausalSelfAttention` that adds the attention
over the adaption prompt."""
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
# kv cache for inference
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
y = super().scaled_dot_product_attention(q, k, v, mask)
if self.block_idx < self.config.adapter_start_layer:
return y
aT = self.config.adapter_prompt_length
if self.adapter_kv_cache is not None:
# since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
# are the same every call
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
if self.config.n_query_groups != 1:
# for MHA this is a no-op
ak = ak.repeat_interleave(q_per_kv, dim=2)
av = av.repeat_interleave(q_per_kv, dim=2)
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
self.adapter_kv_cache = (ak, av)
T = q.size(2)
amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
ay = super().scaled_dot_product_attention(q, ak, av, amask)
return y + self.gating_factor * ay
def reset_parameters(self) -> None:
if hasattr(self, "gating_factor"):
torch.nn.init.zeros_(self.gating_factor)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with older checkpoints."""
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def mark_only_adapter_as_trainable(model: GPT) -> None:
"""Sets `requires_grad=False` for all non-adapter weights."""
for name, param in model.named_parameters():
param.requires_grad = adapter_filter(name, param)
def adapter_filter(key: str, value: Any) -> bool:
return "adapter_wte" in key or "gating_factor" in key
================================================
FILE: litgpt/adapter_v2.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation of the paper:
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
https://arxiv.org/abs/2304.15010
Port for LitGPT
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type
import torch
import torch.nn as nn
from typing_extensions import Self
import litgpt
from litgpt.adapter import GPT as BaseModel
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.adapter import Config as BaseConfig
from litgpt.model import Block as BaseBlock
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights
@dataclass
class Config(BaseConfig):
@property
def mlp_class(self) -> Type:
return getattr(litgpt.adapter_v2, self.mlp_class_name)
def adapter_filter(key: str, value: Any) -> bool:
adapter_substrings = (
# regular adapter v1 parameters
"adapter_wte",
"gating_factor",
# adapter v2: new bias and scale used in Linear
"adapter_scale",
"adapter_bias",
# adapter v2: Norm parameters are now trainable
"norm_1",
"norm_2",
"ln_f",
)
return any(s in key for s in adapter_substrings)
class AdapterV2Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.adapter_scale * (self.linear(x) + self.adapter_bias)
def reset_parameters(self) -> None:
nn.init.zeros_(self.adapter_bias)
nn.init.ones_(self.adapter_scale)
class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.max_seq_length = self.config.block_size
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, AdapterV2Linear):
module.reset_parameters()
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
self.mlp = config.mlp_class(config)
class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
# Copy&paste from :class:`model.CausalSelfAttention`
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
# key, query, value projections for all heads, but in a batch
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
# For compatibility with older checkpoints
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class GptNeoxMLP(litgpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc.weight": "fc.linear.weight",
"fc.bias": "fc.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class LLaMAMLP(litgpt.model.LLaMAMLP):
def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
nn.Module.__init__(self)
self.intermediate_size = intermediate_size or config.intermediate_size
self.fc_1 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias)
self.fc_2 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(self.intermediate_size, config.n_embd, bias=config.bias)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc_1.weight": "fc_1.linear.weight",
"fc_1.bias": "fc_1.linear.bias",
"fc_2.weight": "fc_2.linear.weight",
"fc_2.bias": "fc_2.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
return self.proj(x)
class LLaMAMoE(litgpt.model.LLaMAMoE):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)
self.experts = nn.ModuleList(
LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"gate.weight": "gate.linear.weight"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
"""Sets requires_grad=False for all non-adapter weights"""
for name, param in model.named_parameters():
param.requires_grad = adapter_filter(name, param)
================================================
FILE: litgpt/api.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
#
# This file implements the LitGPT Python API
import sys
import time
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import lightning as L
import numpy as np
import torch
from lightning.fabric.accelerators import CUDAAccelerator
from lightning.fabric.plugins import BitsandbytesPrecision
from tqdm import tqdm
from litgpt.chat.base import generate as stream_generate_fn
from litgpt.config import Config, name_to_config
from litgpt.generate.base import generate as generate_fn
from litgpt.generate.sequentially import sequential
from litgpt.generate.tp import tensor_parallel
from litgpt.model import GPT
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style, save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_nvlink_connectivity,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
save_config,
)
class LLM(torch.nn.Module):
def __init__(
self,
model: GPT,
preprocessor=None,
prompt_style: PromptStyle = None,
devices: Union[int, List[int]] = None,
config: Config = None,
checkpoint_dir: Path = None,
fabric: L.Fabric = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
kv_cache_initialized: bool = False,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None,
) -> None:
super().__init__()
self.model = model
self.preprocessor = preprocessor
self.devices = devices
self.prompt_style = prompt_style
self.config = config
self.checkpoint_dir = checkpoint_dir
self.fabric = fabric
self.generate_strategy = generate_strategy
self.kv_cache_initialized = kv_cache_initialized
self.fixed_kv_cache_size = fixed_kv_cache_size
self.prev_generated_seq_length = 0
"""
LLM model class for inference, pretraining, and finetuning.
Example:
from litgpt.api import LLM
llm = LLM.load("microsoft/phi-2")
text = llm.generate("What do Llamas eat?", top_k=1)
print(text)
"""
@property
def tokenizer(self):
return self.preprocessor.tokenizer
def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)
def forward(
self,
input_ids: torch.Tensor,
target_ids: Optional[torch.Tensor] = None,
loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
logits = self.model(input_ids)
if target_ids is not None:
if loss_fn is None:
loss_fn = chunked_cross_entropy
loss = loss_fn(logits[..., :-1, :], target_ids[..., 1:])
return logits, loss
else:
return logits
def trainer_setup(self, trainer_ckpt: Optional[Path] = None) -> None:
"""Initializes the model checkpoint for PyTorch Lightning Trainer contexts"""
self.model = GPT(self.config)
if trainer_ckpt is not None:
# strip the object name key from the state_dict
state_dict = torch.load(trainer_ckpt, weights_only=True)["state_dict"]
first_key = next(iter(state_dict))
prefix = first_key.split(".")[0] + "."
keys_to_modify = [key for key in state_dict if key.startswith(prefix)]
for key in keys_to_modify:
new_key = key.replace(prefix, "", 1)
state_dict[new_key] = state_dict.pop(key)
self.load_state_dict(state_dict, strict=True)
elif self.checkpoint_dir is not None:
state_dict = torch.load(self.checkpoint_dir / "lit_model.pth", weights_only=False)
self.load_state_dict(state_dict, strict=False)
else:
raise ValueError(
"No checkpoint found. Either provide a valid path via `trainer_ckpt` "
"or ensure that `self.checkpoint_dir` points to a folder containing a `lit_model.pth` weight file."
)
def save(self, out_dir: Optional[Path] = None, prompt_style: Optional[PromptStyle] = None) -> None:
out_dir = Path(out_dir)
save_path = out_dir / "lit_model.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
if prompt_style is None:
prompt_style = PromptStyle.from_config(self.config)
if self.fabric is None:
torch.save(self.state_dict(), save_path)
else:
self.fabric.save(save_path, self.state_dict())
if self.fabric is None or self.fabric.global_rank == 0:
# If initialization a model with random weights, the checkpoint dir can be none
if self.checkpoint_dir is not None:
copy_config_files(Path(self.checkpoint_dir), save_path.parent)
else:
save_config(self.config, out_dir)
save_prompt_style(prompt_style, save_path.parent)
@classmethod
def load(
cls,
model: str,
init: Optional[Literal["pretrained", "random"]] = "pretrained",
tokenizer_dir: Optional[Path] = None,
access_token: Optional[str] = None,
distribute: Optional[Literal["auto"]] = "auto",
) -> "LLM":
"""
Loads the LLM from a local directory or model hub.
Arguments
model: A local path to a directory containing the model weights or a valid model name.
You can get a list of valid model names via the `litgpt download list` command line argument.
init: If "pretrained" (default), downloads the model from the HF Hub if a local model can't be found at the `model`
directory name; otherwise loads the model from the local directory.
If "random", initializes the `model` with random weights.
tokenizer_dir: An optional tokenizer directory if `model` is not a checkpoint directory, or if a user
wants to use a different tokenizer instead.
access_token: Optional API token to access models with restrictions when using `init="pretrained"`.
distribute: If "auto" (default), initializes the model on a single GPU if available and otherwise on the CPU.
To have more control over the model distribution strategy and utilize multiple GPUs, you can set
`llm = LLM.load(..., distribute=None)` and call `llm.distribute(...)` manually.
"""
allowed_init = {"pretrained", "random"}
if init == "pretrained":
checkpoint_dir = auto_download_checkpoint(
model_name=model, access_token=access_token, ignore_tokenizer_files=tokenizer_dir is not None
)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
elif init == "random":
checkpoint_dir = None
try:
config = Config.from_name(model)
except ValueError:
print(f"Model name {model} is not supported.\n")
available_models = "\n".join(sorted(name_to_config))
print(f"Available values:\n{available_models}")
return
else:
raise ValueError(f"Invalid init option: {init}. Must be one of {allowed_init}")
torch.set_float32_matmul_precision("high")
if tokenizer_dir is not None:
tokenizer_dir = extend_checkpoint_dir(Path(tokenizer_dir))
tokenizer = Tokenizer(tokenizer_dir)
elif checkpoint_dir is not None:
tokenizer = Tokenizer(checkpoint_dir)
else:
raise ValueError("Provide a path to a tokenizer directory via the `tokenizer_dir` setting.")
if checkpoint_dir is not None:
prompt_style = (
load_prompt_style(checkpoint_dir)
if has_prompt_style(checkpoint_dir)
else PromptStyle.from_config(config)
)
else:
prompt_style = PromptStyle.from_config(config)
if distribute == "auto":
if torch.cuda.is_available():
accelerator = "cuda"
elif torch.backends.mps.is_available():
accelerator = "mps"
else:
accelerator = "cpu"
fabric = L.Fabric(
accelerator=accelerator,
devices=1,
precision=get_default_supported_precision(training=False),
)
with fabric.init_module(empty_init=False):
model = GPT(config)
model.eval()
preprocessor = Preprocessor(tokenizer, device=fabric.device)
if checkpoint_dir is not None:
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup_module(model)
else:
preprocessor = Preprocessor(tokenizer, device="cuda" if torch.cuda.is_available() else "cpu")
model = None
fabric = None
return cls(
model=model,
preprocessor=preprocessor,
prompt_style=prompt_style,
config=config,
checkpoint_dir=checkpoint_dir,
fabric=fabric,
generate_strategy=None,
kv_cache_initialized=False,
fixed_kv_cache_size=False,
)
def distribute(
self,
accelerator: Literal["cpu", "cuda", "auto"] = "auto",
devices: Union[int, Literal["auto"]] = "auto",
precision: Optional[Any] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None,
) -> None:
"""
Moves the model onto specified devices for single-GPU or multi-GPU inference
accelerator: Which device type to load the model on ("cpu", "gpu", "mps", "cuda", or "auto")
devices: The number of devices (1, 2, etc.) or "auto", which uses all available devices
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
For instance, "32-true", "16-mixed", "16-true", "bf16-mixed", "bf16-true".
For more details, see https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
generate_strategy: Whether to use a sequential model generation strategy. The "sequential" settings allows running
models that wouldn't fit in a single card by partitioning the transformer blocks across
all devices and running them sequentially. Sequential generation may be slower but allows using larger models.
Note that sequential generation sets `fixed_kv_cache_size="max_model_supported"`. You can set it to a lower integer
value, `fixed_kv_cache_size=256` to reduce memory. The `fixed_kv_cache_size` value determines the maximum number
of tokens that can be returned via `llm.generate(...)`.
fixed_kv_cache_size: If set to an integer value or "max_model_supported" is set, the kv-cache won't be resized dynamically
during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy="sequential`.
Note that the chosen `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned in `llm.generate(...)`.
"""
if self.checkpoint_dir is None:
raise NotImplementedError(
"The LLM was initialized with init='random' but .distribute() "
"currently only supports pretrained weights."
)
allowed_accelerators = {"cpu", "gpu", "cuda", "mps", "auto"}
if accelerator not in allowed_accelerators:
raise ValueError(f"Invalid accelerator: {accelerator}. Must be one of {allowed_accelerators}.")
if accelerator == "auto":
if torch.cuda.is_available():
accelerator = "cuda"
elif torch.backends.mps.is_available():
accelerator = "mps"
else:
accelerator = "cpu"
if generate_strategy in ("sequential", "tensor_parallel") and accelerator not in ("cuda", "gpu"):
raise NotImplementedError(
f"generate_strategy='{generate_strategy}' is only supported for accelerator='cuda'|'gpu'."
)
if devices == "auto":
if generate_strategy in ("sequential", "tensor_parallel"):
total_devices = CUDAAccelerator.auto_device_count()
else:
total_devices = 1
elif isinstance(devices, int) and accelerator == "cuda":
use_devices = calculate_number_of_devices(devices)
total_devices = CUDAAccelerator.auto_device_count()
if use_devices > total_devices:
raise ValueError(
f"You selected more devices ({use_devices}) than available in your system ({total_devices})."
)
else:
total_devices = use_devices
if total_devices > 1 and generate_strategy not in ("sequential", "tensor_parallel"):
raise NotImplementedError(
"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
)
elif accelerator == "cpu" or accelerator == "mps":
total_devices = 1
else:
raise ValueError(f"devices argument must be an integer or 'auto', got {devices}")
print(f"Using {total_devices} device(s)", file=sys.stderr)
if precision is None:
precision = get_default_supported_precision(training=False)
print("Precision set", file=sys.stderr)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("The combination of quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
# set "ddp" as the strategy for the launching functionality, but there's no data-parallelism
if generate_strategy != "tensor_parallel":
fabric = L.Fabric(
accelerator=accelerator,
devices=1, # Otherwise sequential wouldn't work, see litgpt/generate/sequentially.py
# devices=devices,
precision=precision,
plugins=plugins,
)
else:
fabric = L.Fabric(
accelerator=accelerator, devices=total_devices, strategy="ddp", precision=precision, plugins=plugins
)
if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1:
check_nvlink_connectivity(fabric)
fabric.launch()
print("Fabric launched", file=sys.stderr)
self.kv_cache_initialized = False
if generate_strategy is None:
with fabric.init_module(empty_init=(total_devices > 1)):
model = GPT(self.config)
model.eval()
if self.checkpoint_dir is not None:
load_checkpoint(fabric, model, self.checkpoint_dir / "lit_model.pth")
model = fabric.setup_module(model)
if fixed_kv_cache_size is not None:
if fixed_kv_cache_size is None or fixed_kv_cache_size == "max_model_supported":
kv_cache_size = model.max_seq_length
else:
kv_cache_size = fixed_kv_cache_size
model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)
self.kv_cache_initialized = True
self.fixed_kv_cache_size = fixed_kv_cache_size
elif generate_strategy in ("sequential", "tensor_parallel"):
with fabric.init_tensor(), torch.device("meta"):
model = GPT(self.config)
model.eval()
if generate_strategy == "sequential":
state_dict = torch.load(
str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False
)
model.load_state_dict(state_dict, assign=True)
model = fabric.setup_module(model, move_to_device=False)
if fixed_kv_cache_size is None:
fixed_kv_cache_size = "max_model_supported"
if fixed_kv_cache_size == "max_model_supported":
kv_cache_size = model.max_seq_length
else:
kv_cache_size = fixed_kv_cache_size
model = sequential(model, fabric.device, kv_cache_size, total_devices)
self.fixed_kv_cache_size = fixed_kv_cache_size
elif generate_strategy == "tensor_parallel":
if fabric.global_rank == 0:
pbar = tqdm(total=fabric.world_size, desc="Loading model weights")
for rank in range(fabric.world_size):
if fabric.global_rank == rank:
state_dict = torch.load(
str(self.checkpoint_dir / "lit_model.pth"),
mmap=True,
map_location="cpu",
weights_only=False,
)
model.load_state_dict(state_dict, assign=True)
# cannot use `.setup_module` because it will wrap with DDP
model = fabric._precision.convert_module(model)
model = tensor_parallel(fabric, model)
with fabric.init_tensor():
if fixed_kv_cache_size is None:
fixed_kv_cache_size = "max_model_supported"
if fixed_kv_cache_size == "max_model_supported":
kv_cache_size = model.max_seq_length
else:
kv_cache_size = fixed_kv_cache_size
model.max_seq_length = kv_cache_size
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
model = fabric.to_device(model)
fabric.barrier()
if fabric.global_rank == 0:
pbar.update(1)
if fabric.global_rank == 0:
pbar.close()
self.kv_cache_initialized = True
else:
raise ValueError(f"Unsupported generate_strategy: {generate_strategy}")
self.model = model
self.fabric = fabric
self.preprocessor.device = fabric.device
@torch.inference_mode()
def generate(
self,
prompt: str,
sys_prompt: Optional[str] = None,
max_new_tokens: int = 50,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
return_as_token_ids: bool = False,
stream: bool = False,
) -> Union[str, torch.Tensor]:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
Arguments:
model: The model to use.
prompt: The prompt string to use for generating the samples.
sys_prompt: The system prompt string to use for generating the samples.
The system prompt allows the user to provide additional instructions to shape all responses by providing additional context, behavioral guidelines, style, and constraints.
max_new_tokens: The maximum number of new tokens to return.
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
return_as_token_ids: If True, returns the token IDs as a torch.Tensor. Otherwise, returns the decoded text as a string.
stream: If True, returns a generator that yields tokens as they are generated.
At the moment, this setting is slower and may use more memory than the non-streaming version.
We plan to resolve this in the future.
"""
if self.model is None:
raise AttributeError(
"The model is not initialized yet; use the .distribute() "
"or .trainer_setup() method to initialize the model."
)
input_ids = self._text_to_token_ids(prompt, sys_prompt)
prompt_length = input_ids.size(0)
max_returned_tokens = prompt_length + max_new_tokens
if not self.kv_cache_initialized:
if self.fabric is not None:
device = self.fabric.device
else:
device = self.preprocessor.device
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)
self.kv_cache_initialized = True
# Dynamically grow the kv cache size if necessary
if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
tmp_device = self.model.mask_cache.device
self.model.clear_kv_cache()
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)
else:
for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()
self.prev_generated_seq_length = max_returned_tokens
self.model.eval()
def iterator():
outputs = stream_generate_fn(
model=self.model,
prompt=input_ids,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=([self.preprocessor.tokenizer.eos_id],),
)
if return_as_token_ids:
yield from outputs
else:
for output in outputs:
yield self.preprocessor.decode(output)
return
if stream:
outputs = iterator()
else:
outputs = generate_fn(
model=self.model,
prompt=input_ids,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_id=self.preprocessor.tokenizer.eos_id,
include_prompt=False,
)
if stream:
return outputs
elif return_as_token_ids:
return outputs
else:
return self.preprocessor.decode(outputs)
def _text_to_token_ids(self, prompt: str, sys_prompt: Optional[str] = None) -> torch.Tensor:
"""Utility method to convert a prompt text to token IDs"""
prompt = self.prompt_style.apply(prompt, sys_prompt=sys_prompt)
input_ids = self.preprocessor.encode(prompt)
return input_ids
def benchmark(self, num_iterations=1, **kwargs):
"""
A wrapper around the .generate() method to calculate runtime performance.
Arguments:
num_iterations: How often the `.generate()` call is repeated.
kwargs: Keyword arguments that are passed to the .generate() method.
"""
benchmark_dict = {}
for i in range(num_iterations):
time_to_first_token = None
t0 = time.perf_counter()
outputs = self.generate(**kwargs)
if kwargs.get("stream", False):
gen_outputs = []
for e in outputs:
if time_to_first_token is None:
t1 = time.perf_counter()
time_to_first_token = t1 - t0
gen_outputs.append(e)
outputs = "".join(gen_outputs)
else:
outputs = self.generate(
**kwargs,
)
benchmark_dict.setdefault("Seconds total", []).append(time.perf_counter() - t0)
benchmark_dict.setdefault("Seconds to first token", []).append(time_to_first_token)
tokens_generated = self.preprocessor.encode(outputs).size(0)
benchmark_dict.setdefault("Tokens generated", []).append(tokens_generated)
benchmark_dict.setdefault("Inference speed in tokens/sec", []).append(
benchmark_dict["Tokens generated"][-1] / benchmark_dict["Seconds total"][-1]
)
if self.fabric is not None and self.fabric.device.type == "cuda":
benchmark_dict.setdefault("Total GPU memory allocated in GB", []).append(
torch.cuda.max_memory_allocated() / 1e9
)
return outputs, benchmark_dict
class Preprocessor:
"""
Preprocessor class for tokenization and de-tokenization.
"""
def __init__(self, tokenizer: Tokenizer, device: str = "cpu") -> None:
self.tokenizer = tokenizer
self.device = device
def encode(self, text: str) -> torch.Tensor:
return self.tokenizer.encode(text, device=self.device)
def decode(self, token_ids: torch.Tensor) -> str:
return self.tokenizer.decode(token_ids)
def calculate_number_of_devices(devices):
"""
Utility function to calculate the number of devices.
"""
num_devices = devices if isinstance(devices, int) else len(devices) if isinstance(devices, list) else 0
return num_devices
def benchmark_dict_to_markdown_table(data):
"""
Converts .benchmark() outputs to a markdown table
"""
markdown_table = (
"| Metric | Mean | Std Dev |\n"
)
markdown_table += (
"|-------------------------------------|-----------------------------|-----------------------------|\n"
)
for key, values in data.items():
mean_value = np.mean(values)
std_dev_value = np.std(values, ddof=1)
formatted_mean = f"{mean_value:.2f}"
formatted_std_dev = f"{std_dev_value:.2f}"
markdown_table += f"| {key.ljust(35)} | {formatted_mean.ljust(27)} | {formatted_std_dev.ljust(27)} |\n"
return markdown_table
def pull_request_benchmark_util(model_name="microsoft/phi-2", num_iterations=6):
def print_table(header, data):
print(f"\n### {header}\n")
markdown_table = (
f"| Metric | First Iteration | "
f"Iter 2-{num_iterations} Mean | Iter 2-{num_iterations} Standard Dev. |\n"
f"|--------------------------------------|-----------------|"
f"-------------------|-------------------------|\n"
)
for key, value in data.items():
first_iteration = f"{value[0]:.2f}" if value[0] is not None else "N/A"
clean_values = [v for v in value[1:] if v is not None]
if clean_values:
mean_value = np.mean(clean_values)
std_dev_value = np.std(clean_values, ddof=1)
mean_str = f"{mean_value:.2f}"
std_dev_str = f"{std_dev_value:.2f}"
else:
mean_str = "N/A"
std_dev_str = "N/A"
markdown_table += f"| {key:<36} | {first_iteration:<15} | {mean_str:<17} | {std_dev_str:<23} |\n"
print(markdown_table)
import subprocess
try:
g_hash = subprocess.run(
["git", "rev-parse", "--short", "HEAD"], capture_output=True, text=True, check=True
).stdout.strip()
print(f"Git Commit Hash: {g_hash}")
except subprocess.CalledProcessError:
print("Git Commit Hash: N/A")
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")
# 1st table
llm = LLM.load(
model=model_name,
)
text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1)
print_table(f"Defaults ({model_name}), 1st time", bench_d)
del llm
# 2nd table
llm = LLM.load(
model=model_name,
)
text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1)
print_table(f"Defaults ({model_name}), 2nd time", bench_d)
del llm
# 3rd table
llm = LLM.load(
model=model_name,
)
text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1, stream=True)
print_table("stream=True", bench_d)
del llm
# 4th table
llm = LLM.load(model=model_name, distribute=None)
llm.distribute(fixed_kv_cache_size=500)
text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt="What do llamas eat?", top_k=1, stream=True)
print_table("stream=True + fixed_kv_cache=500", bench_d)
================================================
FILE: litgpt/args.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import math
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Union
@dataclass
class TrainArgs:
"""Training-related arguments"""
save_interval: Optional[int] = 1000
"""Number of optimizer steps between saving checkpoints"""
log_interval: int = 1
"""Number of iterations between logging calls"""
global_batch_size: int = 64
"""Number of samples between optimizer steps across data-parallel ranks"""
micro_batch_size: int = 4
"""Number of samples per data-parallel rank"""
lr_warmup_steps: Optional[int] = 100
"""Number of iterations with learning rate warmup active"""
lr_warmup_fraction: Optional[float] = None
"""The fraction of an epoch to use for learning rate warmup"""
epochs: Optional[int] = None
"""Number of epochs to train on"""
# TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
max_tokens: Optional[int] = None
"""Total number of tokens to train on"""
max_steps: Optional[int] = None
"""Limits the number of optimizer steps to run"""
max_time: Optional[float] = None
"""Limits the number of seconds to train for"""
max_seq_length: Optional[int] = None
"""Limits the length of samples"""
tie_embeddings: Optional[bool] = None
"""Whether to tie the embedding weights with the language modeling head weights"""
# Optimization args
max_norm: Optional[float] = None
min_lr: float = 6e-5
def __post_init__(self) -> None:
if self.lr_warmup_fraction and self.lr_warmup_steps:
raise ValueError(
"Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one."
)
if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1):
raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.")
if self.lr_warmup_steps and self.max_steps and (self.lr_warmup_steps >= self.max_steps):
warnings.warn(
"`--train.lr_warmup_steps` should be less than `--train.max_steps`."
f" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.",
UserWarning,
)
def gradient_accumulation_iters(self, devices: int, num_nodes: int = 1) -> int:
"""Number of iterations between gradient synchronizations"""
gradient_accumulation_iters = self.batch_size(devices, num_nodes) // self.micro_batch_size
assert gradient_accumulation_iters > 0
return gradient_accumulation_iters
def batch_size(self, devices: int, num_nodes: int = 1) -> int:
"""Number of samples between optimizer steps per data-parallel rank"""
batch_size = self.global_batch_size // (devices * num_nodes)
assert batch_size > 0
return batch_size
def warmup_iters(self, devices: int, num_nodes: int, max_iters: int, train_dataloader) -> int:
"""Number of iterations to warm up the learning rate."""
if self.lr_warmup_fraction:
return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader)))
if self.lr_warmup_steps:
return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices, num_nodes))
return 0
@dataclass
class EvalArgs:
"""Evaluation-related arguments"""
interval: int = 600
"""Number of optimizer steps between evaluation calls"""
max_new_tokens: Optional[int] = None
"""Number of tokens to generate"""
max_iters: int = 100
"""Number of iterations"""
initial_validation: bool = False
"""Whether to evaluate on the validation set at the beginning of the training"""
final_validation: bool = True
"""Whether to evaluate on the validation set at the end of the training"""
evaluate_example: Union[str, int] = "first"
"""How to pick an example instruction to evaluate periodically during training.
Can be "first", "random", or an integer index to pick a specific example."""
@dataclass
class LogArgs:
"""Logging-related arguments. Different loggers use different fields."""
# === WandB Fields ===
project: Optional[str] = None
"""WandB project name"""
run: Optional[str] = None
"""WandB run name (defaults to generated name)"""
group: Optional[str] = None
"""WandB group name"""
# === LitLogger Fields (Lightning.ai) ===
teamspace: Optional[str] = None
"""Teamspace name where charts and artifacts will appear"""
metadata: Optional[Dict] = None
"""Extra metadata to associate with the experiment as tags"""
log_model: bool = False
"""If True, automatically log model checkpoints as artifacts"""
save_logs: bool = True
"""If True, capture and upload terminal logs"""
checkpoint_name: Optional[str] = None
"""Override the base name for logged checkpoints"""
================================================
FILE: litgpt/chat/__init__.py
================================================
================================================
FILE: litgpt/chat/base.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
from pathlib import Path
from pprint import pprint
from typing import Iterator, List, Literal, Optional, Tuple
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt.config import Config
from litgpt.model import GPT
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
)
@torch.inference_mode()
def generate(
model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
stop_tokens: Tuple[List[int], ...] = (),
) -> Iterator[torch.Tensor]:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as possible.
Arguments:
model: The model to use.
prompt: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
stop_tokens: If specified, stop generating any more token once one of this list is generated.
"""
from litgpt.generate.base import generate_fn
return generate_fn(
include_prompt=False,
include_eos=False,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens,
)
def process_prompt(
prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens
):
prompt = prompt_style.apply(prompt=prompt)
encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
if max_new_tokens is None:
max_returned_tokens = model.max_seq_length
else:
first_turn = model.mask_cache is None
max_returned_tokens = encoded_prompt.size(0) + max_new_tokens
if first_turn or max_returned_tokens > model.max_seq_length:
model.max_seq_length = max_returned_tokens
model.set_kv_cache(batch_size=1, device=fabric.device)
y: Iterator[torch.Tensor] = generate(
model,
encoded_prompt,
max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens,
)
token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device)
fabric.print(">> Reply: ", end="")
t0 = time.perf_counter()
tokens_generated = 0
for tok in token_generator:
tokens_generated += 1
fabric.print(tok, end="", flush=True)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens",
file=sys.stderr,
)
fabric.print()
def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
while True:
try:
if not multiline:
prompt = input(">> Prompt: ")
else:
print(">> Prompt: (Type '!submit' on a new line to end input).")
prompt_lines = []
while True:
line = input()
if line.strip().lower() in ("!submit", "!quit", "!exit"):
break
prompt_lines.append(line)
prompt = "\n".join(prompt_lines)
except KeyboardInterrupt:
break
prompt = prompt.strip()
if not prompt or prompt.lower() in ("!quit", "!exit"):
break
process_prompt(
prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens
)
@torch.inference_mode()
def main(
checkpoint_dir: Path,
*,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
multiline: bool = False,
access_token: Optional[str] = None,
) -> None:
"""Chat with a model.
Args:
checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
You can get a list of valid model names via the `litgpt download list` command line argument.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to use compilation to speed up token generation. Will increase startup time.
multiline: Whether to support multiline input prompts.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
# Merge if this is a raw LoRA checkpoint
checkpoint_path = checkpoint_dir / "lit_model.pth"
if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_dir)
if not checkpoint_path.is_file():
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
with fabric.init_module(empty_init=True):
model = GPT(config)
if compile:
print(
"IMPORTANT: with enabled compilation the KV-cache size is determined by model's maximum context size, which leads to "
"a higher memory consumption. In case of an OOM error, try to set `--compile=False`."
)
model.set_kv_cache(batch_size=1)
load_checkpoint(fabric, model, checkpoint_path)
model.eval()
if compile:
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
global next_token
next_token = torch.compile(next_token, mode="reduce-overhead", dynamic=True)
model = fabric.setup_module(model)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
stop_tokens = prompt_style.stop_tokens(tokenizer)
if multiline:
exit_instruction = "To exit, enter '!quit' or '!exit' on an empty prompt and press 'Enter'."
else:
exit_instruction = "To exit, press 'Enter' on an empty prompt."
print(f"Now chatting with {config.name}.\n{exit_instruction}\n")
L.seed_everything(1234)
interact(
multiline=multiline,
model=model,
tokenizer=tokenizer,
prompt_style=prompt_style,
fabric=fabric,
temperature=temperature,
max_new_tokens=(None if compile else max_new_tokens),
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens,
)
if fabric.device.type == "cuda":
fabric.print(f"\nMemory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/config.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Literal, Optional, Type, Union
import yaml
from typing_extensions import Self
def find_multiple(n: int, k: int) -> int:
"""Utility function for finding the nearest value to n which is a multiple of k.
NOTE: We define this function in this module rather than `litgpt.utils` so that users can import
this file to do configuration manipulations in Python environments which do not include all the dependencies
demanded by `litgpt.utils`.
"""
assert k > 0
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
# General size parameters
block_size: int = 4096
n_layer: int = 16
n_embd: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
# Transformer block (structure, normalizations)
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
norm_qk: bool = False
norm_qk_type: Literal["default", "olmo2"] = "default"
post_attention_norm: bool = False
post_mlp_norm: bool = False
parallel_residual: bool = True
shared_attention_norm: bool = False
# Transformer block (self-attention)
n_head: int = 32
head_size: Optional[int] = None
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
# Example with `n_head=4`
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ │ │ │
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
attn_bias: bool = False
attention_scores_scalar: Optional[int] = None
# If `sliding_window_size` is given, sliding window attention with this
# size is used in layers where `sliding_window_indices` has a 1. The
# default is all 1, so that sliding window attention is used in all
# layers. If `len(sliding_window_indices) > n_layer`, we only use the
# initial part.
sliding_window_size: Optional[int] = None
sliding_window_indices: Optional[List[int]] = None
# if `attention_logit_softcapping` is used, cannot use optimized
# `torch.nn.functional.scaled_dot_product_attention` (which implements
# Flash attention), may result in higher memory and runtime footprint.
attention_logit_softcapping: Optional[float] = None
# Rotary position embedding (RoPE)
rope_base: int = 10000
rotary_percentage: float = 0.25
rope_condense_ratio: int = 1
rope_adjustments: Optional[dict] = None
rope_interleave: bool = False
# Transformer block (MLP)
intermediate_size: Optional[int] = None
moe_intermediate_size: Optional[int] = None
bias: bool = True
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
n_expert: int = 0
n_shared_expert: Optional[int] = None
n_expert_groups: Optional[int] = None
n_topk_groups: Optional[int] = None
n_topk_scores_per_group: Optional[int] = None
n_expert_per_token: int = 0
first_k_dense_replace: Optional[int] = None
routed_scaling_factor: float = 1.0
norm_topk_prob: bool = False
# GPT before/after blocks
scale_embeddings: bool = False
lm_head_bias: bool = False
final_logit_softcapping: Optional[float] = None
norm_1: bool = True
norm_2: bool = True
latent_attention: Optional[dict] = None
# The base period of the RoPE embeddings for local attention.
# If not provided, `rope_base` will be used for both local and global attention.
rope_local_base_freq: Optional[float] = None
# If provided, must have `>= n_layer` entries, either 0 or 1. For 0,
# `rope_base` is used, for 1 `rope_local_base_freq` is used. If
# `len(rope_indices) > n_layer`, we only use the initial part.
rope_indices: Optional[List[int]] = None
def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
else:
# vocab size shouldn't be larger than padded vocab size
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
# compute the number of query groups
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
# compute the intermediate size for MLP if not set
if self.intermediate_size is None:
if self.mlp_class_name == "LLaMAMLP":
raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
self.intermediate_size = 4 * self.n_embd
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
if self.sliding_window_size is not None:
self.sliding_window_indices = check_indicator_and_length(
self.sliding_window_indices,
name="sliding_window_indices",
required_length=self.n_layer,
)
if self.rope_local_base_freq is not None:
self.rope_indices = check_indicator_and_length(
self.rope_indices,
name="rope_indices",
required_length=self.n_layer,
)
if self.latent_attention is not None:
self.q_lora_rank = self.latent_attention.get("q_lora_rank")
self.kv_lora_rank = self.latent_attention.get("kv_lora_rank")
self.qk_rope_head_dim = self.latent_attention.get("qk_rope_head_dim")
self.qk_nope_head_dim = self.latent_attention.get("qk_nope_head_dim")
self.v_head_dim = self.latent_attention.get("v_head_dim")
assert (
self.q_lora_rank
and self.kv_lora_rank
and self.qk_rope_head_dim
and self.qk_nope_head_dim
and self.v_head_dim
) is not None
assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA"
self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
self.rope_n_elem = self.qk_rope_head_dim
if self.first_k_dense_replace is not None:
assert self.mlp_class_name == "LLaMAMoE"
if self.n_expert_groups is not None:
assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1
assert self.n_topk_groups is not None
experts_per_group = self.n_expert // self.n_expert_groups
assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
if name not in name_to_config:
# search through all `config['hf_config']['name']`
try:
conf_dict = next(
config
for config in configs
if name == config["hf_config"]["name"]
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
)
except StopIteration:
raise ValueError(f"{name!r} is not a supported config name")
else:
conf_dict = name_to_config[name]
conf_dict = conf_dict.copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@classmethod
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
file_kwargs = yaml.safe_load(fp)
if file_kwargs is None:
raise ValueError(f"{path} is empty which is likely unexpected.")
file_kwargs.update(kwargs)
return cls(**file_kwargs)
@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "model_config.yaml").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.")
@property
def mlp_class(self) -> Type:
# `self.mlp_class_name` cannot be the type to keep the config serializable
import litgpt.model
return getattr(litgpt.model, self.mlp_class_name)
@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config serializable
from functools import partial
import torch # Torch import is lazy to make config loading faster
if self.norm_class_name == "RMSNorm":
from litgpt.model import RMSNorm
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
if self.norm_class_name == "LayerNorm" and "OLMo" in self.name:
# this makes it equivalent to `torch.nn.functional.layer_norm`
# that is used by OLMo
# Table 5 caption in the OLMo paper shows this - https://aclanthology.org/2024.acl-long.841
return partial(torch.nn.LayerNorm, elementwise_affine=False)
return getattr(torch.nn, self.norm_class_name)
def check_indicator_and_length(
params: Optional[List[int]],
name: str,
required_length: int,
use_initial_part: bool = True,
def_val: int = 1,
) -> List[int]:
if params is None:
return [def_val] * required_length
if len(params) != required_length:
if use_initial_part and len(params) > required_length:
params = params[:required_length]
else:
raise ValueError(f"{name} = {params}, must have length {required_length}")
if not set(params).issubset({0, 1}):
raise ValueError(f"{name} = {params}, must only contain 0 and 1")
return params
########################
# Stability AI StableLM
########################
configs = [
# https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
dict(name="stablelm-base-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b")),
# https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
dict(
name="stablelm-base-alpha-7b",
hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
n_head=48,
n_embd=6144,
padding_multiple=256,
),
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
dict(name="stablelm-tuned-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), n_head=32),
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
dict(
name="stablelm-tuned-alpha-7b",
hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
n_head=48,
n_embd=6144,
padding_multiple=256,
),
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
dict(
name="stablelm-3b-4e1t",
hf_config=dict(org="stabilityai", name="stablelm-3b-4e1t"),
padded_vocab_size=50304,
n_layer=32,
n_head=32,
n_embd=2560,
parallel_residual=False,
bias=False,
mlp_class_name="LLaMAMLP",
intermediate_size=6912,
),
# https://huggingface.co/stabilityai/stablelm-zephyr-3b/blob/main/config.json
dict(
name="stablelm-zephyr-3b",
hf_config=dict(org="stabilityai", name="stablelm-zephyr-3b"),
padded_vocab_size=50304,
n_layer=32,
n_head=32,
n_embd=2560,
parallel_residual=False,
bias=False,
mlp_class_name="LLaMAMLP",
intermediate_size=6912,
),
]
##########################
# Stability AI StableCode
##########################
stablecode = [
# https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
dict(
name="stablecode-completion-alpha-3b",
hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
block_size=16384,
vocab_size=49152,
n_layer=32,
n_embd=2560,
),
# https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
dict(
name="stablecode-completion-alpha-3b-4k",
hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
vocab_size=49152,
n_layer=32,
n_embd=2560,
),
# https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
dict(
name="stablecode-instruct-alpha-3b",
hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
vocab_size=49152,
n_layer=32,
n_embd=2560,
),
# https://huggingface.co/stabilityai/stable-code-3b/blob/main/config.json
dict(
name="stable-code-3b",
hf_config=dict(org="stabilityai", name="stable-code-3b"),
padded_vocab_size=50304,
n_layer=32,
n_embd=2560,
block_size=16384,
parallel_residual=False,
bias=False,
mlp_class_name="LLaMAMLP",
intermediate_size=6912,
),
]
configs.extend(stablecode)
####################
# EleutherAI Pythia
####################
pythia = [
# https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json
dict(
name="pythia-14m",
hf_config=dict(org="EleutherAI", name="pythia-14m"),
block_size=512,
n_layer=6,
n_embd=128,
n_head=4,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json
dict(
name="pythia-31m",
hf_config=dict(org="EleutherAI", name="pythia-31m"),
block_size=1024,
n_layer=6,
n_embd=256,
n_head=8,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
dict(
name="pythia-70m",
hf_config=dict(org="EleutherAI", name="pythia-70m"),
block_size=2048,
n_layer=6,
n_embd=512,
n_head=8,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
dict(
name="pythia-160m",
hf_config=dict(org="EleutherAI", name="pythia-160m"),
block_size=2048,
n_layer=12,
n_embd=768,
n_head=12,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
dict(
name="pythia-410m",
hf_config=dict(org="EleutherAI", name="pythia-410m"),
block_size=2048,
n_layer=24,
n_embd=1024,
n_head=16,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
dict(
name="pythia-1b",
hf_config=dict(org="EleutherAI", name="pythia-1b"),
block_size=2048,
n_embd=2048,
n_head=8,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
dict(
name="pythia-1.4b",
hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
block_size=2048,
n_layer=24,
n_embd=2048,
n_head=16,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
dict(
name="pythia-2.8b",
hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
block_size=2048,
n_layer=32,
n_embd=2560,
padding_multiple=128,
),
# https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
dict(
name="pythia-6.9b",
hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
block_size=2048,
n_layer=32,
padding_multiple=256,
),
# https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
dict(
name="pythia-12b",
hf_config=dict(org="EleutherAI", name="pythia-12b"),
block_size=2048,
n_layer=36,
n_embd=5120,
n_head=40,
),
]
configs.extend(pythia)
for c in pythia:
# "pythia-14m" and "pythia-31m" don't have deduped version
if c["name"] in ("pythia-14m", "pythia-31m"):
continue
copy = deepcopy(c)
copy["name"] = f"{c['name']}-deduped"
copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
configs.append(copy)
#################
# TII UAE Falcon
#################
falcon = [
# https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
dict(
name="falcon-7b{}",
hf_config=dict(org="tiiuae", name="falcon-7b{}"),
block_size=2048,
vocab_size=65024,
padded_vocab_size=65024,
n_layer=32,
n_head=71,
n_embd=4544,
rotary_percentage=1.0,
n_query_groups=1,
bias=False,
# this is not in the config, but in the original model implementation, only for this config
shared_attention_norm=True,
),
# https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
dict(
name="falcon-40b{}",
hf_config=dict(org="tiiuae", name="falcon-40b{}"),
block_size=2048,
vocab_size=65024,
padded_vocab_size=65024,
n_layer=60,
n_head=128,
n_embd=8192,
rotary_percentage=1.0,
n_query_groups=8,
bias=False,
),
]
for c in falcon:
for kind in ("", "-instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
falcon180b = dict(
name="falcon-180B{}",
hf_config=dict(org="tiiuae", name="falcon-180B{}"),
block_size=2048,
vocab_size=65024,
padded_vocab_size=65024,
n_layer=80,
n_head=232,
n_embd=14848,
rotary_percentage=1.0,
n_query_groups=8,
bias=False,
)
for kind in ("", "-chat"):
copy = deepcopy(falcon180b)
copy["name"] = falcon180b["name"].format(kind)
copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
configs.append(copy)
falcon3 = [
# https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json
dict(
name="Falcon3-1B{}",
hf_config=dict(org="tiiuae", name="Falcon3-1B{}"),
block_size=4096,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=18,
n_head=8,
n_query_groups=4,
n_embd=2048,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
),
# https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json
dict(
name="Falcon3-3B{}",
hf_config=dict(org="tiiuae", name="Falcon3-3B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=22,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=9216,
),
# https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json
dict(
name="Falcon3-7B{}",
hf_config=dict(org="tiiuae", name="Falcon3-7B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=28,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=23040,
),
# https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json
dict(
name="Falcon3-10B{}",
hf_config=dict(org="tiiuae", name="Falcon3-10B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=40,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=23040,
),
]
for c in falcon3:
for kind in ("-Base", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
#############################
# OpenLM Research Open LLaMA
#############################
open_LLaMA = [
# https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
dict(
name="open_llama_3b",
hf_config=dict(org="openlm-research", name="open_llama_3b"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=26,
n_embd=3200,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-6,
mlp_class_name="LLaMAMLP",
intermediate_size=8640,
),
# https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
dict(
name="open_llama_7b",
hf_config=dict(org="openlm-research", name="open_llama_7b"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-6,
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
dict(
name="open_llama_13b",
hf_config=dict(org="openlm-research", name="open_llama_13b"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-6,
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
),
]
configs.extend(open_LLaMA)
###############
# Meta LLaMA 2
###############
llama_2 = [
# https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
dict(
name="Llama-2-7b{}-hf",
hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
vocab_size=32000,
padding_multiple=64,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
dict(
name="Llama-2-13b{}-hf",
hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
dict(
name="Llama-2-70b{}-hf",
hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
vocab_size=32000,
padding_multiple=64,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
),
]
for c in llama_2:
for kind in ("", "-chat"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
###############
# Meta LLaMA 3
###############
llama_3 = [
# https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json
dict(
name="Llama-3-8B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"),
block_size=8192,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=32,
n_head=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
),
# https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/config.json
dict(
name="Llama-3.1-8B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-8B{}"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=32,
n_head=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
dict(
name="Llama-3-70B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"),
block_size=8192,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
),
# https://huggingface.co/meta-llama/Meta-Llama-3.1-70B/blob/main/config.json
dict(
name="Llama-3.1-70B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-70B{}"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json
dict(
name="Llama-3.1-405B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3.1-405B{}"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=126,
n_head=128,
n_embd=16384,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=53248,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
dict(
name="Llama-3.2-1B{}",
hf_config=dict(org="meta-llama", name="Llama-3.2-1B{}"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=16,
n_embd=2048,
n_head=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
dict(
name="Llama-3.2-3B{}",
hf_config=dict(org="meta-llama", name="Llama-3.2-3B{}"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=28,
n_embd=3072,
n_head=24,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json
dict(
name="Llama-3.3-70B-Instruct",
hf_config=dict(org="meta-llama", name="Llama-3.3-70B-Instruct"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
]
for c in llama_3:
if c["name"] == "Llama-3.3-70B-Instruct":
configs.append(c)
continue
for kind in ("", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
#########################
# NVIDIA Llama Nemotron
#########################
configs.append(
dict(
name="Llama-3.1-Nemotron-70B-Instruct-HF",
hf_config=dict(org="nvidia", name="Llama-3.1-Nemotron-70B-Instruct-HF"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
)
#################
# Allen AI OLMo
#################
olmo = [
# https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json
dict(
name="OLMo-1B-hf",
hf_config=dict(org="allenai", name="OLMo-1B-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_embd=2048,
n_layer=16,
n_head=16,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
),
# https://huggingface.co/allenai/OLMo-7B-hf/blob/main/config.json
dict(
name="OLMo-7B-hf",
hf_config=dict(org="allenai", name="OLMo-7B-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/allenai/OLMo-7B-Instruct-hf/blob/main/config.json
dict(
name="OLMo-7B-Instruct-hf",
hf_config=dict(org="allenai", name="OLMo-7B-Instruct-hf"),
vocab_size=50280,
padded_vocab_size=50304,
block_size=2048,
n_layer=32,
n_head=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="LayerNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
]
configs.extend(olmo)
olmo2 = [
# https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
dict(
name="OLMo-2-1124-7B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=4096,
n_layer=32,
n_head=32,
n_query_groups=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=11008,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
norm_1=False,
norm_2=False,
norm_qk_type="olmo2",
post_attention_norm=True,
),
# https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
dict(
name="OLMo-2-1124-13B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=40,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=13824,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
norm_1=False,
norm_2=False,
norm_qk_type="olmo2",
post_attention_norm=True,
),
]
for c in olmo2:
for kind in ("", "-SFT", "-DPO", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
###############
# Google Gemma
###############
gemma = [
# https://huggingface.co/google/gemma-2b/blob/main/config.json
dict(
name="Gemma-2b",
hf_config=dict(org="google", name="gemma-2b"),
scale_embeddings=True,
vocab_size=256000,
padding_multiple=64,
n_embd=2048,
n_layer=18,
n_head=8,
n_query_groups=1,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=16384,
),
# https://huggingface.co/google/gemma-7b/blob/main/config.json
dict(
name="Gemma-7b",
hf_config=dict(org="google", name="gemma-7b"),
scale_embeddings=True,
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=24576,
),
# https://huggingface.co/google/gemma-2-2b/blob/main/config.json
dict(
name="Gemma-2-2b",
hf_config=dict(org="google", name="gemma-2-2b"),
scale_embeddings=True,
attention_scores_scalar=256,
vocab_size=256000,
block_size=8192,
sliding_window_size=4096,
# only layer with idx 0, 2, 4, ... have sliding window attention
sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(26)],
intermediate_size=9216,
n_embd=2304,
n_layer=26,
n_head=8,
n_query_groups=4,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
attention_logit_softcapping=50.0,
final_logit_softcapping=30.0,
),
# https://huggingface.co/google/gemma-2-9b/blob/main/config.json
dict(
name="Gemma-2-9b",
hf_config=dict(org="google", name="gemma-2-9b"),
scale_embeddings=True,
attention_scores_scalar=256,
vocab_size=256000,
block_size=8192,
sliding_window_size=4096,
# only layer with idx 0, 2, 4, ... have sliding window attention
sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(42)],
intermediate_size=14336,
n_embd=3584,
n_layer=42,
n_head=16,
n_query_groups=8,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
attention_logit_softcapping=50.0,
final_logit_softcapping=30.0,
),
# https://huggingface.co/google/gemma-2-27b/blob/main/config.json
dict(
name="Gemma-2-27b",
hf_config=dict(org="google", name="gemma-2-27b"),
scale_embeddings=True,
# In Gemma 2 27B attention scores are scaled not by `sqrt(head_size)` (11.31),
# but by `sqrt(n_emb // n_head)` = sqrt(4608 // 32) = 12
attention_scores_scalar=144,
vocab_size=256000,
block_size=8192,
sliding_window_size=4096,
# only layer with idx 0, 2, 4, ... have sliding window attention
sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(46)],
intermediate_size=36864,
n_embd=4608,
n_layer=46,
n_head=32,
n_query_groups=16,
head_size=128,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
attention_logit_softcapping=50.0,
final_logit_softcapping=30.0,
),
]
configs.extend(gemma)
for c in gemma:
copy = deepcopy(c)
copy["name"] = f"{c['name']}-it"
copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
configs.append(copy)
##################
# Google Gemma 3
##################
gemma3 = [
# https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json
dict(
name="Gemma-3-1b-it",
hf_config=dict(org="google", name="gemma-3-1b-it"),
scale_embeddings=True,
attention_scores_scalar=256,
vocab_size=262144,
block_size=131072,
sliding_window_size=512,
# 5 local layers for every global layer
sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
intermediate_size=6912,
n_embd=1152,
n_layer=26,
n_head=4,
n_query_groups=1,
head_size=256,
rotary_percentage=1.0,
rope_adjustments=None,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
norm_qk=True,
rope_base=1000000,
rope_local_base_freq=10000,
# 5 local layers for every global layer
rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
),
# https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
dict(
name="Gemma-3-4b-it",
hf_config=dict(org="google", name="gemma-3-4b-it"),
scale_embeddings=True,
attention_scores_scalar=256,
vocab_size=262144,
block_size=131072,
sliding_window_size=1024,
# 5 local layers for every global layer
sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
intermediate_size=10240,
n_embd=2560,
n_layer=34,
n_head=8,
n_query_groups=4,
head_size=256,
rotary_percentage=1.0,
rope_adjustments=dict(factor=8.0),
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
norm_qk=True,
rope_base=1000000,
rope_local_base_freq=10000,
# 5 local layers for every global layer
rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
),
# https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json
dict(
name="Gemma-3-12b-it",
hf_config=dict(org="google", name="gemma-3-12b-it"),
scale_embeddings=True,
attention_scores_scalar=256,
vocab_size=262144,
block_size=131072,
sliding_window_size=1024,
# 5 local layers for every global layer
sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],
intermediate_size=15360,
n_embd=3840,
n_layer=48,
n_head=16,
n_query_groups=8,
head_size=256,
rotary_percentage=1.0,
rope_adjustments=dict(factor=8.0),
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
norm_qk=True,
rope_base=1000000,
rope_local_base_freq=10000,
# 5 local layers for every global layer
rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],
),
# https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json
dict(
name="Gemma-3-27b-it",
hf_config=dict(org="google", name="gemma-3-27b-it"),
scale_embeddings=True,
attention_scores_scalar=168,
vocab_size=262144,
block_size=131072,
sliding_window_size=1024,
# 5 local layers for every global layer
sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],
intermediate_size=21504,
n_embd=5376,
n_layer=62,
n_head=32,
n_query_groups=16,
head_size=128,
rotary_percentage=1.0,
rope_adjustments=dict(factor=8.0),
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
post_attention_norm=True,
post_mlp_norm=True,
norm_qk=True,
rope_base=1000000,
rope_local_base_freq=10000,
# 5 local layers for every global layer
rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],
),
]
configs.extend(gemma3)
##################
# Google CodeGemma
##################
codegemma = [
# https://huggingface.co/google/codegemma-7b-it/blob/main/config.json
dict(
name="CodeGemma-7b-it",
hf_config=dict(org="google", name="codegemma-7b-it"),
scale_embeddings=True,
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=24576,
),
]
configs.extend(codegemma)
##########################
# Stability AI FreeWilly2
##########################
freewilly_2 = [
# https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
dict(
name="FreeWilly2",
hf_config=dict(org="stabilityai", name="FreeWilly2"),
vocab_size=32000,
padding_multiple=64,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
)
]
configs.extend(freewilly_2)
##################
# Meta Code Llama
##################
code_llama = [
# https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
dict(
name="CodeLlama-7b-hf",
hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
dict(
name="CodeLlama-13b-hf",
hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
dict(
name="CodeLlama-34b-hf",
hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
block_size=16384,
vocab_size=32000,
padded_vocab_size=32000,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-70b-hf/blob/main/config.json
dict(
name="CodeLlama-70b-hf",
hf_config=dict(org="codellama", name="CodeLlama-70b-hf"),
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
dict(
name="CodeLlama-7b-Python-hf",
hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
block_size=16384,
vocab_size=32000,
padded_vocab_size=32000,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
dict(
name="CodeLlama-13b-Python-hf",
hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
block_size=16384,
vocab_size=32000,
padded_vocab_size=32000,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
dict(
name="CodeLlama-34b-Python-hf",
hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
block_size=16384,
vocab_size=32000,
padded_vocab_size=32000,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-70b-Python-hf/blob/main/config.json
dict(
name="CodeLlama-70b-Python-hf",
hf_config=dict(org="codellama", name="CodeLlama-70b-Python-hf"),
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json
dict(
name="CodeLlama-7b-Instruct-hf",
hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
block_size=16384,
vocab_size=32016,
padding_multiple=16,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
dict(
name="CodeLlama-13b-Instruct-hf",
hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
block_size=2048,
vocab_size=32016,
padding_multiple=16,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
dict(
name="CodeLlama-34b-Instruct-hf",
hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
block_size=16384,
vocab_size=32000,
padded_vocab_size=32000,
n_layer=48,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=22016,
rope_base=1000000,
),
# https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/config.json
dict(
name="CodeLlama-70b-Instruct-hf",
hf_config=dict(org="codellama", name="CodeLlama-70b-Instruct-hf"),
block_size=16384,
# 32016 is an added token, so not reported in vocab_size
# https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/tokenizer_config.json
vocab_size=32015,
padding_multiple=16,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=1000000,
),
]
configs.extend(code_llama)
########################
# garage-bAInd Platypus
########################
platypus = [
# https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
dict(
name="Platypus-30B",
hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
block_size=2048,
padded_vocab_size=32000,
n_layer=60,
n_head=52,
n_embd=6656,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-06,
mlp_class_name="LLaMAMLP",
intermediate_size=17920,
),
# https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
dict(
name="Platypus2-7B",
hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
padded_vocab_size=32000,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
dict(
name="Platypus2-13B",
hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
padded_vocab_size=32000,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
dict(
name="Platypus2-70B",
hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
padded_vocab_size=32000,
n_layer=80,
n_head=64,
n_embd=8192,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
),
# https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
dict(
name="Camel-Platypus2-13B",
hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
padded_vocab_size=32000,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
dict(
name="Camel-Platypus2-70B",
hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
padded_vocab_size=32000,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
),
# https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
dict(
name="Stable-Platypus2-13B",
hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
padded_vocab_size=32000,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
dict(
name="Platypus2-70B-instruct",
hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
padded_vocab_size=32000,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
),
]
configs.extend(platypus)
##################################
# togethercomputer LLaMA-2-7B-32K
##################################
together_llama2_32k = [
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
dict(
name="LLaMA-2-7B-32K",
hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
vocab_size=32000,
padding_multiple=64,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_condense_ratio=8,
)
]
configs.extend(together_llama2_32k)
################
# Microsoft Phi
################
phi = [
# https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
dict(
name="phi-1_5",
hf_config=dict(org="microsoft", name="phi-1_5"),
vocab_size=50257,
padded_vocab_size=51200,
block_size=2048,
n_embd=2048,
n_layer=24,
rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
shared_attention_norm=True,
lm_head_bias=True,
gelu_approximate="tanh",
),
# https://huggingface.co/microsoft/phi-2/blob/main/config.json
dict(
name="phi-2",
hf_config=dict(org="microsoft", name="phi-2"),
vocab_size=50257,
padded_vocab_size=51200,
block_size=2048,
n_embd=2560,
n_layer=32,
rotary_percentage=0.4, # 32 / (n_embd / n_head) = 32 / 80
shared_attention_norm=True,
lm_head_bias=True,
gelu_approximate="tanh",
),
# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
dict(
name="Phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
vocab_size=32000,
padded_vocab_size=32064,
block_size=4096,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
sliding_window_size=2048,
),
# https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json
dict(
name="Phi-3-mini-128k-instruct",
hf_config=dict(org="microsoft", name="Phi-3-mini-128k-instruct"),
vocab_size=32000,
padded_vocab_size=32064,
block_size=131072,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
sliding_window_size=262145,
),
# https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json
dict(
name="Phi-3.5-mini-instruct",
hf_config=dict(org="microsoft", name="Phi-3.5-mini-instruct"),
vocab_size=32000,
padded_vocab_size=32064,
block_size=4096,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
# https://huggingface.co/microsoft/phi-4/blob/main/config.json
dict(
name="phi-4",
hf_config=dict(org="microsoft", name="phi-4"),
vocab_size=100352,
padded_vocab_size=100352,
block_size=16384,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=10,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=17920,
rope_base=250000,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
# https://huggingface.co/microsoft/Phi-4-reasoning/blob/main/config.json
dict(
name="Phi-4-reasoning",
hf_config=dict(org="microsoft", name="Phi-4-reasoning"),
vocab_size=100352,
padded_vocab_size=100352,
block_size=32768,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=10,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=17920,
rope_base=500000,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
# https://huggingface.co/microsoft/Phi-4-reasoning-plus/blob/main/config.json
dict(
name="Phi-4-reasoning-plus",
hf_config=dict(org="microsoft", name="Phi-4-reasoning-plus"),
vocab_size=100352,
padded_vocab_size=100352,
block_size=32768,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=10,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=17920,
rope_base=500000,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
# https://huggingface.co/microsoft/Phi-4-mini-instruct/blob/main/config.json
dict(
name="Phi-4-mini-instruct",
hf_config=dict(org="microsoft", name="Phi-4-mini-instruct"),
vocab_size=200019,
padded_vocab_size=200064,
block_size=131072,
n_embd=3072,
n_layer=32,
n_head=24,
n_query_groups=8,
rotary_percentage=0.75,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
sliding_window_size=262145,
),
# https://huggingface.co/microsoft/Phi-4-mini-reasoning/blob/main/config.json
dict(
name="Phi-4-mini-reasoning",
hf_config=dict(org="microsoft", name="Phi-4-mini-reasoning"),
vocab_size=200019,
padded_vocab_size=200064,
block_size=131072,
n_embd=3072,
n_layer=32,
n_head=24,
n_query_groups=8,
rotary_percentage=0.75,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
sliding_window_size=262145,
),
]
configs.extend(phi)
#############
# Mistral AI
#############
configs.append(
# https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json
dict(
name="Mathstral-7B-v0.1",
hf_config=dict(org="mistralai", name="mathstral-7B-v0.1"),
padded_vocab_size=32768,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
sliding_window_size=4096,
)
)
mistral = [
# https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
dict(
name="Mistral-7B-{}v0.1",
hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
padded_vocab_size=32000,
block_size=4096, # should be 32768 but sliding window attention is not implemented
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
sliding_window_size=4096,
),
# https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
dict(
name="Mixtral-8x7B-{}v0.1",
hf_config=dict(org="mistralai", name="Mixtral-8x7B-{}v0.1"),
padded_vocab_size=32000,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMoE",
intermediate_size=14336,
rope_base=1000000,
n_expert=8,
n_expert_per_token=2,
),
# https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json
dict(
name="Mixtral-8x22B-{}v0.1",
hf_config=dict(org="mistralai", name="Mixtral-8x22B-{}v0.1"),
padded_vocab_size=32768,
block_size=65536,
n_layer=56,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMoE",
intermediate_size=16384,
n_head=48,
n_embd=6144,
rope_base=1000000,
n_expert=8,
n_expert_per_token=2,
),
]
for c in mistral:
for kind in ("", "Instruct-"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
configs.append(
# https://huggingface.co/unsloth/mistral-7b-v0.2/blob/main/config.json
dict(
name="Mistral-7B-v0.2",
hf_config=dict(org="unsloth", name="Mistral-7B-v0.2"),
padded_vocab_size=32000,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
)
)
configs.append(
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json
dict(
name="Mistral-7B-Instruct-v0.2",
hf_config=dict(org="mistralai", name="Mistral-7B-Instruct-v0.2"),
padded_vocab_size=32000,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
)
)
configs.append(
# https://huggingface.co/mistralai/Mistral-7B-v0.3/blob/main/config.json
dict(
name="Mistral-7B-v0.3",
hf_config=dict(org="mistralai", name="Mistral-7B-v0.3"),
padded_vocab_size=32768,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
)
)
configs.append(
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/blob/main/config.json
dict(
name="Mistral-7B-Instruct-v0.3",
hf_config=dict(org="mistralai", name="Mistral-7B-Instruct-v0.3"),
padded_vocab_size=32768,
block_size=32768,
n_layer=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
)
)
configs.append(
# https://huggingface.co/mistralai/Mistral-Large-Instruct-2407/blob/main/config.json
dict(
name="Mistral-Large-Instruct-2407",
hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2407"),
padded_vocab_size=32768,
block_size=32768,
n_layer=88,
n_head=96,
n_embd=12288,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
)
)
configs.append(
# https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/config.json
dict(
name="Mistral-Large-Instruct-2411",
hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2411"),
padded_vocab_size=32768,
block_size=32768,
n_layer=88,
n_head=96,
n_embd=12288,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
)
)
############
# TinyLlama
############
tiny_llama = [
dict(
name="tiny-llama-1.1b{}",
hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=22,
n_head=32,
n_embd=2048,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm", # original TinyLlama use FusedRMSNorm
norm_eps=1e-5,
mlp_class_name="LLaMAMLP",
intermediate_size=5632,
n_query_groups=4,
)
]
for c in tiny_llama:
for kind, hf_postfix in (("", "-intermediate-step-1431k-3T"), ("-chat", "-Chat-v1.0")):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix)
configs.append(copy)
############
# MicroLlama
############
micro_llama = [
dict(
name="micro-llama-300M",
hf_config=dict(org="keeeeenw", name="MicroLlama"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=12,
n_head=16,
n_embd=1024,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm", # original TinyLlama and MicroLlama use FusedRMSNorm
norm_eps=1e-5,
mlp_class_name="LLaMAMLP",
intermediate_size=5632,
n_query_groups=4,
)
]
configs.extend(micro_llama)
##########################
# Trelis Function Calling
##########################
llama_2_function_calling = [
# https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2/blob/main/config.json
dict(
name="Llama-2-7b-chat-hf-function-calling-v2",
hf_config=dict(org="Trelis", name="Llama-2-7b-chat-hf-function-calling-v2"),
padding_multiple=64,
n_layer=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
block_size=4096,
vocab_size=32000,
n_head=32,
n_embd=4096,
rope_base=10000,
)
]
configs.extend(llama_2_function_calling)
##########
# Qwen2.5
##########
qwen_2_5 = [
# https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json
dict(
name="Qwen2.5-0.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-0.5B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=24,
n_head=14,
n_embd=896,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=4864,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-1.5B/blob/main/config.json
dict(
name="Qwen2.5-1.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-1.5B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=12,
n_embd=1536,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8960,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json
dict(
name="Qwen2.5-3B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-3B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=16,
n_embd=2048,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json
dict(
name="Qwen2.5-7B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-7B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-14B/blob/main/config.json
dict(
name="Qwen2.5-14B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-14B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=48,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
norm_eps=1e-5,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-32B/blob/main/config.json
dict(
name="Qwen2.5-32B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-32B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-72B/blob/main/config.json
dict(
name="Qwen2.5-72B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-72B{}"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=29568,
norm_eps=1e-5,
rope_base=1000000,
),
]
qwen_2_5_coder = [
# https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B/blob/main/config.json
dict(
name="Qwen2.5-Coder-0.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-0.5B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=24,
n_head=14,
n_embd=896,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=4864,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B/blob/main/config.json
dict(
name="Qwen2.5-Coder-1.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=12,
n_embd=1536,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8960,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-3B/blob/main/config.json
dict(
name="Qwen2.5-Coder-3B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-3B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=16,
n_embd=2048,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-7B/blob/main/config.json
dict(
name="Qwen2.5-Coder-7B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-6,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-14B/blob/main/config.json
dict(
name="Qwen2.5-Coder-14B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-14B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=48,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
norm_eps=1e-5,
rope_base=1000000,
),
# https://huggingface.co/Qwen/Qwen2.5-Coder-32B/blob/main/config.json
dict(
name="Qwen2.5-Coder-32B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000,
),
]
qwen_2_5.extend(qwen_2_5_coder)
qwen_2_5_math = [
# https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json
dict(
name="Qwen2.5-Math-1.5B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Math-1.5B{}"),
block_size=4096,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=12,
n_embd=1536,
n_query_groups=2,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8960,
norm_eps=1e-6,
rope_base=10000,
),
# https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json
dict(
name="Qwen2.5-Math-7B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Math-7B{}"),
block_size=4096,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-6,
rope_base=10000,
),
# https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json
dict(
name="Qwen2.5-Math-72B{}",
hf_config=dict(org="Qwen", name="Qwen2.5-Math-72B{}"),
block_size=4096,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=29568,
norm_eps=1e-5,
rope_base=10000,
),
]
qwen_2_5.extend(qwen_2_5_math)
for c in qwen_2_5:
for kind in ("", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
qwen_2_5_1m = [
# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-1M/blob/main/config.json
dict(
name="Qwen2.5-7B-Instruct-1M",
hf_config=dict(org="Qwen", name="Qwen2.5-7B-Instruct-1M"),
block_size=1010000,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=28,
n_head=28,
n_embd=3584,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=18944,
norm_eps=1e-5,
rope_base=10000000,
),
# https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json
dict(
name="Qwen2.5-14B-Instruct-1M",
hf_config=dict(org="Qwen", name="Qwen2.5-14B-Instruct-1M"),
block_size=1010000,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=48,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=13824,
norm_eps=1e-5,
rope_base=10000000,
),
]
configs.extend(qwen_2_5_1m)
##########
# QwQ
##########
qwq = [
# https://huggingface.co/Qwen/QwQ-32B/blob/main/config.json
dict(
name="QwQ-32B",
hf_config=dict(org="Qwen", name="QwQ-32B"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000,
),
# https://huggingface.co/Qwen/QwQ-32B-Preview/blob/main/config.json
dict(
name="QwQ-32B-Preview",
hf_config=dict(org="Qwen", name="QwQ-32B-Preview"),
block_size=32768,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000,
),
]
configs.extend(qwq)
##########
# Qwen3
##########
qwen_3 = [
# https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
dict(
name="Qwen3-0.6B{}",
hf_config=dict(org="Qwen", name="Qwen3-0.6B{}"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=16,
n_embd=1024,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=3072,
norm_eps=1e-6,
rope_base=1000000,
head_size=128,
norm_qk=True,
),
# https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json
dict(
name="Qwen3-1.7B{}",
hf_config=dict(org="Qwen", name="Qwen3-1.7B{}"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=16,
n_embd=2048,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=6144,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
# https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json
dict(
name="Qwen3-4B{}",
hf_config=dict(org="Qwen", name="Qwen3-4B{}"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=32,
n_embd=2560,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=9728,
norm_eps=1e-6,
rope_base=1000000,
head_size=128,
norm_qk=True,
),
# https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json
dict(
name="Qwen3-8B{}",
hf_config=dict(org="Qwen", name="Qwen3-8B{}"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=32,
n_embd=4096,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=12288,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
# https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json
dict(
name="Qwen3-14B{}",
hf_config=dict(org="Qwen", name="Qwen3-14B{}"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=40,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=17408,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
]
for c in qwen_3:
for kind in ("", "-Base"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
qwen_3_32b = [
# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json
dict(
name="Qwen3-32B",
hf_config=dict(org="Qwen", name="Qwen3-32B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=64,
n_head=64,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=25600,
norm_eps=1e-6,
rope_base=1000000,
head_size=128,
norm_qk=True,
),
]
configs.extend(qwen_3_32b)
qwen_3_moe = [
# https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json
dict(
name="Qwen3-30B-A3B",
hf_config=dict(org="Qwen", name="Qwen3-30B-A3B"),
block_size=40960,
head_size=128,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=48,
n_head=32,
n_embd=2048,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMoE",
intermediate_size=6144,
moe_intermediate_size=768,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
n_expert=128,
n_expert_per_token=8,
),
# https://huggingface.co/Qwen/Qwen3-30B-A3B-Base/blob/main/config.json
dict(
name="Qwen3-30B-A3B-Base",
hf_config=dict(org="Qwen", name="Qwen3-30B-A3B-Base"),
block_size=40960,
head_size=128,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=48,
n_head=32,
n_embd=2048,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMoE",
intermediate_size=6144,
moe_intermediate_size=768,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
n_expert=128,
n_expert_per_token=8,
),
# https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json
dict(
name="Qwen3-235B-A22B",
hf_config=dict(org="Qwen", name="Qwen3-235B-A22B"),
block_size=40960,
head_size=128,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=94,
n_head=64,
n_embd=4096,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMoE",
intermediate_size=12288,
moe_intermediate_size=1536,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
n_expert=128,
n_expert_per_token=8,
),
]
configs.extend(qwen_3_moe)
qwen_3_2507_thinking_instruct = [
# https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507/blob/main/config.json
dict(
name="Qwen3-235B-A22B-{}-2507",
hf_config=dict(org="Qwen", name="Qwen3-235B-A22B-{}-2507"),
block_size=262144,
head_size=128,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=94,
n_head=64,
n_embd=4096,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMoE",
intermediate_size=12288,
moe_intermediate_size=1536,
norm_eps=1e-6,
rope_base=5000000,
norm_qk=True,
n_expert=128,
n_expert_per_token=8,
),
# https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507/blob/main/config.json
dict(
name="Qwen3-30B-A3B-{}-2507",
hf_config=dict(org="Qwen", name="Qwen3-30B-A3B-{}-2507"),
block_size=262144,
head_size=128,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=48,
n_head=32,
n_embd=2048,
n_query_groups=4,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMoE",
intermediate_size=6144,
moe_intermediate_size=768,
norm_eps=1e-6,
rope_base=10000000,
norm_qk=True,
n_expert=128,
n_expert_per_token=8,
),
# https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507/blob/main/config.json
dict(
name="Qwen3-4B-{}-2507",
hf_config=dict(org="Qwen", name="Qwen3-4B-{}-2507"),
block_size=262144,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=32,
n_embd=2560,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=9728,
norm_eps=1e-6,
rope_base=5000000,
head_size=128,
norm_qk=True,
),
]
for c in qwen_3_2507_thinking_instruct:
for kind in ("Thinking", "Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
#############
# Salamandra
#############
salamandra = [
# https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json
dict(
name="salamandra-2b{}",
hf_config=dict(org="BSC-LT", name="salamandra-2b{}"),
block_size=8192,
vocab_size=256000,
padded_vocab_size=256000,
n_layer=24,
n_head=16,
n_embd=2048,
n_query_groups=16,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=5440,
norm_eps=1e-5,
rope_base=10000,
),
# https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json
dict(
name="salamandra-7b{}",
hf_config=dict(org="BSC-LT", name="salamandra-7b{}"),
block_size=8192,
vocab_size=256000,
padded_vocab_size=256000,
n_layer=32,
n_head=32,
n_embd=4096,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
norm_eps=1e-6,
rope_base=10000,
),
]
for c in salamandra:
for kind in ("", "-instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
###############
# SmolLM2
###############
smollm2 = [
# https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json
dict(
name="SmolLM2-135M{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=30,
n_head=9,
n_embd=576,
n_query_groups=3,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=1536,
rope_base=100000,
norm_eps=1e-5,
),
# https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json
dict(
name="SmolLM2-360M{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=32,
n_head=15,
n_embd=960,
n_query_groups=5,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=2560,
rope_base=100000,
norm_eps=1e-5,
),
# https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json
dict(
name="SmolLM2-1.7B{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=24,
n_head=32,
n_embd=2048,
n_query_groups=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=130000,
norm_eps=1e-5,
),
]
for c in smollm2:
for kind in ("", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)
###############
# DeepSeek R1 Distill
###############
r1_distill_llama = [
# https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/blob/main/config.json
dict(
name="R1-Distill-Llama-8B",
hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-8B"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=32,
n_head=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
# https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/blob/main/config.json
dict(
name="R1-Distill-Llama-70B",
hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-70B"),
block_size=131072,
vocab_size=128000,
padded_vocab_size=128256,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
),
]
configs.extend(r1_distill_llama)
name_to_config = {config["name"]: config for config in configs}
================================================
FILE: litgpt/constants.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Centralized package availability constants for optional dependencies."""
from lightning_utilities.core.imports import RequirementCache
# Logger-related constants
_SUPPORTED_LOGGERS: tuple[str, ...] = ("csv", "tensorboard", "wandb", "mlflow", "litlogger")
# Logger-related optional dependencies
_LITLOGGER_AVAILABLE = RequirementCache("litlogger>=0.1.7")
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
_WANDB_AVAILABLE = RequirementCache("wandb")
_MLFLOW_AVAILABLE = RequirementCache("mlflow")
_MLFLOW_SKINNY_AVAILABLE = RequirementCache("mlflow-skinny")
# PyTorch version-specific constants
_TORCH_EQUAL_2_7 = RequirementCache("torch>=2.7.0,<2.8")
_TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9")
# Other optional dependencies
_REQUESTS_AVAILABLE = RequirementCache("requests")
_THUNDER_AVAILABLE = RequirementCache("thunder")
_TRITON_AVAILABLE = RequirementCache("triton")
_BITANDBYTES_AVAILABLE = RequirementCache("bitsandbytes")
_BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 = RequirementCache("bitsandbytes != 0.42.0")
_LITDATA_AVAILABLE = RequirementCache("litdata")
_LITSERVE_AVAILABLE = RequirementCache("litserve")
_JINJA2_AVAILABLE = RequirementCache("jinja2")
_SAFETENSORS_AVAILABLE = RequirementCache("safetensors")
_HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer")
================================================
FILE: litgpt/data/__init__.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from litgpt.data.alpaca import Alpaca
from litgpt.data.alpaca_2k import Alpaca2k
from litgpt.data.alpaca_gpt4 import AlpacaGPT4
from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.data.deita import Deita
from litgpt.data.flan import FLAN
from litgpt.data.json_data import JSON
from litgpt.data.lima import LIMA
from litgpt.data.lit_data import LitData
from litgpt.data.longform import LongForm
from litgpt.data.microllama import MicroLlama
from litgpt.data.openwebtext import OpenWebText
from litgpt.data.text_files import TextFiles
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
__all__ = [
"Alpaca",
"Alpaca2k",
"AlpacaGPT4",
"Deita",
"FLAN",
"JSON",
"LIMA",
"LitData",
"DataModule",
"LongForm",
"OpenWebText",
"SFTDataset",
"TextFiles",
"TinyLlama",
"TinyStories",
"MicroLlama",
"get_sft_collate_fn",
]
================================================
FILE: litgpt/data/alpaca.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
import torch
from torch.utils.data import DataLoader, random_split
from litgpt.constants import _REQUESTS_AVAILABLE
from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
_URL = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
@dataclass
class Alpaca(DataModule):
"""Alpaca data module for supervised finetuning."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
val_split_fraction: float = 0.03865 # to get exactly 2000 validation samples,
"""The fraction of the dataset to use for the validation dataset. The rest is used for training."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/alpaca")
"""The directory in which the downloaded dataset gets saved."""
file_url: str = field(repr=False, default=_URL)
"""The URL from where to download the dataset."""
file_name: str = field(repr=False, default="alpaca_data_cleaned_archive.json")
"""The name of the dataset file to download."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
super().__init__()
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def prepare_data(self) -> None:
self.download_dir.mkdir(parents=True, exist_ok=True)
download_if_missing(self.download_dir / self.file_name, self.file_url)
def setup(self, stage: str = "") -> None:
with open(self.download_dir / self.file_name, encoding="utf-8") as file:
data = json.load(file)
# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.val_split_fraction, self.val_split_fraction],
generator=torch.Generator().manual_seed(self.seed),
)
train_data, test_data = list(train_data), list(test_data)
self.train_dataset = SFTDataset(
data=train_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
self.test_dataset = SFTDataset(
data=test_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def download_if_missing(file_path: Path, file_url: str, mode: str = "w", stream: bool = False) -> None:
"""Downloads the raw json data file and saves it in the given destination."""
if file_path.exists() and file_path.stat().st_size > 0:
return
if not _REQUESTS_AVAILABLE:
raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE))
import requests
response = requests.get(file_url, stream=stream)
with open(file_path, mode, encoding=None if mode == "wb" else "utf-8") as f:
if stream:
# credit: https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py#L25-L38
from tqdm import tqdm
pbar = tqdm(
desc=str(file_path),
total=int(response.headers.get("content-length", 0)),
unit="iB",
unit_scale=True,
unit_divisor=1024,
)
for data in response.iter_content(chunk_size=1024):
size = f.write(data)
pbar.update(size)
pbar.close()
else:
f.write(response.text)
================================================
FILE: litgpt/data/alpaca_2k.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass, field
from pathlib import Path
from litgpt.data.alpaca import Alpaca
from litgpt.data.base import SFTDataset
@dataclass
class Alpaca2k(Alpaca):
"""Alpaca2k data module for supervised finetuning."""
val_split_fraction: float = 0.05 # to get exactly 100 validation samples,
"""The fraction of the dataset to use for the validation dataset. The rest is used for training."""
download_dir: Path = Path("./data/alpaca2k")
"""The directory in which the downloaded datasetgets saved."""
repo_id: str = field(repr=False, default="mhenrichsen/alpaca_2k_test")
"""The URL from where to download the dataset."""
file_name: str = field(repr=False, default="alpaca2k_data_cleaned_archive.json")
"""The name of the dataset file to download."""
def prepare_data(self) -> None:
from datasets import load_dataset
load_dataset(self.repo_id, cache_dir=self.download_dir)
def setup(self, stage: str = "") -> None:
from datasets import load_dataset
dataset = load_dataset(self.repo_id, cache_dir=self.download_dir)
train_validation_split = dataset["train"].train_test_split(test_size=self.val_split_fraction, seed=self.seed)
train_data = train_validation_split["train"]
test_data = train_validation_split["test"]
self.train_dataset = SFTDataset(
data=train_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
self.test_dataset = SFTDataset(
data=test_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
================================================
FILE: litgpt/data/alpaca_gpt4.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass, field
from pathlib import Path
from litgpt.data.alpaca import Alpaca
_URL = "https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json"
@dataclass
class AlpacaGPT4(Alpaca):
"""AlpacaGPT4 data module for supervised finetuning."""
val_split_fraction: float = 0.03847 # to get exactly 2000 test samples,
"""The fraction of the dataset to use for the validation dataset. The rest is used for training."""
download_dir: Path = Path("./data/alpacagpt4")
"""The directory in which the downloaded datasetgets saved."""
file_url: str = field(repr=False, default=_URL)
"""The URL from where to download the dataset."""
file_name: str = field(repr=False, default="alpacagpt4_data_cleaned_archive.json")
"""The name of the dataset file to download."""
================================================
FILE: litgpt/data/base.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from abc import abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import Dataset
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
class DataModule(LightningDataModule):
"""Base class for all data modules in LitGPT."""
@abstractmethod
def connect(
self,
tokenizer: Optional[Tokenizer] = None,
batch_size: int = 1,
max_seq_length: Optional[int] = None,
**kwargs,
) -> None:
"""All settings that can't be determined at the time of instantiation need to be passed through here
before any dataloaders can be accessed.
"""
def setup(self, stage: str = "") -> None:
# Stub is to redefine the default signature, because the concept of 'stage' does not exist in LitGPT
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
class SFTDataset(Dataset):
"""An in-memory dataset for supervised finetuning with `input_ids` and `labels`.
Args:
data: A list of samples (dicts). The target/label must be stored under the key 'output' and the instruction
or other data can be stored under any key as long as it is compatible with the given prompt template.
tokenizer: The tokenizer to use. Should match the one that was used to pretrain the model.
prompt_style: The style to apply to prompts. See `litgpt.prompts` for a list of available styles.
max_seq_length: Truncate sequences that are longer than this value. By default, no truncation is applied.
mask_prompt: Whether to mask the prompt section from the label (with ``ignore_index``).
ignore_index: The index to use for elements to be ignored in the label.
transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the
keys in the dataset to the expected 'instruction' and 'output' keys.
Returns a dict with two keys:
input_ids: The encoded prompt + response
labels: Same as input_ids, unless ``mask_prompt=True`` in which case the 'prompt' part is replaced with
the ``ignore_index``.
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: Tokenizer,
prompt_style: Union[str, PromptStyle],
max_seq_length: int = -1,
mask_prompt: bool = True,
ignore_index: int = -100,
transform: Optional[Callable[[Any], Any]] = None,
) -> None:
self.data = data
self.tokenizer = tokenizer
self.prompt_style = (
prompt_style if isinstance(prompt_style, PromptStyle) else PromptStyle.from_name(prompt_style)
)
self.max_seq_length = max_seq_length
self.mask_prompt = mask_prompt
self.ignore_index = ignore_index
self.transform = transform
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, Union[Tensor, Dict[str, int]]]:
example = self.data[idx]
if self.transform is not None:
example = self.transform(example)
prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)
encoded_response = self.tokenizer.encode(example["output"], bos=False, eos=True, max_length=self.max_seq_length)
encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)
if self.max_seq_length > 0: # do not slice off last token when self.max_seq_length = -1
encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length]
# The labels are the full prompt with response, but with the prompt masked out
labels = encoded_prompt_and_response.clone()
if self.mask_prompt:
labels[: len(encoded_prompt)] = self.ignore_index
raw_token_count = len(self.tokenizer.encode(example["instruction"], max_length=self.max_seq_length)) + len(
encoded_response
)
return {
"input_ids": encoded_prompt_and_response,
"labels": labels,
"token_counts": {
"raw": raw_token_count,
"raw_plus_prompt_template": len(encoded_prompt_and_response),
},
}
def get_sft_collate_fn(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100):
"""Returns the collate function for supervised finetuning (needed in the DataLoader).
The collate function gets a list of dicts with keys `input_ids` and `labels`.
It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in
the batch. Optionally truncates all sequences to the specified maximum length.
"""
return partial(_sft_collate_fn, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index)
def _sft_collate_fn(
samples: List[Dict[str, Tensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100
) -> Dict[str, Tensor]:
batched = {}
for key in ("input_ids", "labels"):
pad_value = pad_id if key == "input_ids" else ignore_index
# Pad right based on the longest sequence
batched[key] = torch.nn.utils.rnn.pad_sequence(
[sample[key] for sample in samples], batch_first=True, padding_value=pad_value
)
# Truncate if needed
if max_seq_length > 0:
batched[key] = batched[key][:, :max_seq_length]
batched["token_counts"] = {}
batched["token_counts"]["raw"] = torch.tensor( # Token count without padding and without prompt template
[sample["token_counts"]["raw"] for sample in samples], dtype=torch.int64
).unsqueeze(1)
batched["token_counts"]["raw_plus_prompt_template"] = (
torch.tensor( # Token count without padding but with prompt template
[sample["token_counts"]["raw_plus_prompt_template"] for sample in samples], dtype=torch.int64
).unsqueeze(1)
)
return batched
================================================
FILE: litgpt/data/deita.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Union
import torch
from torch.utils.data import DataLoader
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
@dataclass
class Deita(DataModule):
"""Deita data module for supervised finetuning."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
include_multiturn_conversations: bool = False
"""Whether to include multi-turn conversations in the dataset."""
download_dir: Path = Path("./data/deita")
"""The directory in which the downloaded dataset gets saved."""
repo_id: str = "HuggingFaceH4/deita-10k-v0-sft"
"""The repo from where the data is downloaded"""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
super().__init__()
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def prepare_data(self) -> None:
from datasets import load_dataset
load_dataset(self.repo_id, split=["train_sft", "test_sft"], cache_dir=self.download_dir)
def setup(self, stage: str = "") -> None:
from datasets import load_dataset
dataset = load_dataset(self.repo_id, split=["train_sft", "test_sft"])
train_data = format_dataset(dataset[0], self.include_multiturn_conversations)
test_data = format_dataset(dataset[1], self.include_multiturn_conversations)
self.train_dataset = SFTDataset(
data=train_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
self.test_dataset = SFTDataset(
data=test_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def format_dataset(dataset: List[dict], include_multi_turn_conversations: bool) -> List[dict]:
formatted = []
for entry in dataset:
convo = entry["messages"]
if include_multi_turn_conversations:
for i in range(0, len(convo) - 1, 2):
formatted.append({"instruction": convo[i]["content"], "input": "", "output": convo[i + 1]["content"]})
else:
formatted.append({"instruction": convo[0]["content"], "input": "", "output": convo[1]["content"]})
return formatted
================================================
FILE: litgpt/data/flan.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
import torch
from torch.utils.data import DataLoader
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.data.alpaca import download_if_missing
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
_URL = "https://huggingface.co/datasets/Muennighoff/flan/resolve/main"
# TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache
# on disk or use Lightning Data
@dataclass
class FLAN(DataModule):
"""FLAN data module for supervised finetuning."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
prompt_style: Union[str, PromptStyle] = "flan"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/flan")
"""The directory in which the downloaded dataset gets saved."""
url: str = _URL
"""The URL from where to download the dataset."""
subsets: Optional[str] = None
"""A comma separated list of subsets to use. If None, all subsets are used."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self):
super().__init__()
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
supported_subsets = _supported_subsets()
if self.subsets is not None:
self.subsets = self.subsets.split(",")
for subset in self.subsets:
if subset not in supported_subsets:
raise ValueError(f"{subset} not in {supported_subsets}")
else:
self.subsets = list(supported_subsets)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def prepare_data(self) -> None:
self.download_dir.mkdir(parents=True, exist_ok=True)
for subset in self.subsets:
for split in ("train", "test"):
data_file_path = self.download_dir / f"{subset}_{split}.jsonl"
data_file_url = f"{self.url}/{split}/{subset}_{split}.jsonl"
download_if_missing(data_file_path, data_file_url)
def train_dataloader(self):
return self._dataloader("train")
def val_dataloader(self):
return self._dataloader("test")
def _dataloader(self, split: str) -> DataLoader:
data = []
for subset in self.subsets:
data_file_path = self.download_dir / f"{subset}_{split}.jsonl"
data.extend(load_jsonl(data_file_path))
dataset = SFTDataset(
data=data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
transform=_transform,
)
return DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=(split == "train"),
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def load_jsonl(filename: Path) -> List[Dict[str, str]]:
data = []
with open(filename, encoding="utf-8") as f:
for line in f:
data.append(json.loads(line))
return data
def _transform(item: dict) -> dict:
item["instruction"] = item.pop("inputs")
item["output"] = item.pop("targets")
return item
def _supported_subsets() -> Set[str]:
return {
"aeslc_10templates",
"ag_news_subset_10templates",
"anli_r1_10templates",
"anli_r2_10templates",
"anli_r3_10templates",
"arc_challenge_10templates",
"arc_easy_10templates",
"bool_q_10templates",
"cb_10templates",
"cnn_dailymail_10templates",
"cola_10templates",
"common_gen_10templates",
"copa_10templates",
"coqa_10templates",
"cosmos_qa_10templates",
"dart_10templates",
"definite_pronoun_resolution_10templates",
"drop_10templates",
"e2e_nlg_10templates",
"fix_punct_10templates",
"gigaword_10templates",
"glue_mrpc_10templates",
"glue_qqp_10templates",
"hellaswag_10templates",
"imdb_reviews_10templates",
"math_dataset_10templates",
"mnli_matched_10templates",
"mnli_mismatched_10templates",
"multi_news_10templates",
"multirc_10templates",
"natural_questions_10templates",
"openbookqa_10templates",
"opinion_abstracts_idebate_10templates",
"opinion_abstracts_rotten_tomatoes_10templates",
"para_crawl_enes_10templates",
"paws_wiki_10templates",
"piqa_10templates",
"qnli_10templates",
"quac_10templates",
"record_10templates",
"rte_10templates",
"samsum_10templates",
"sentiment140_10templates",
"snli_10templates",
"squad_v1_10templates",
"squad_v2_10templates",
"sst2_10templates",
"story_cloze_10templates",
"stsb_10templates",
"trec_10templates",
"trivia_qa_10templates",
"true_case_10templates",
"web_nlg_en_10templates",
"wic_10templates",
"wiki_lingua_english_en_10templates",
"wmt14_enfr_10templates",
"wmt16_translate_csen_10templates",
"wmt16_translate_deen_10templates",
"wmt16_translate_fien_10templates",
"wmt16_translate_roen_10templates",
"wmt16_translate_ruen_10templates",
"wmt16_translate_tren_10templates",
"wnli_10templates",
"word_segment_10templates",
"wsc_10templates",
"yelp_polarity_reviews_10templates",
}
================================================
FILE: litgpt/data/json_data.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Tuple, Union
import torch
from torch.utils.data import DataLoader, random_split
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
@dataclass
class JSON(DataModule):
"""Loads JSON or JSONL data for supervised finetuning."""
json_path: Path
"""A path to a JSON file or a directory with `train.json` and `val.json` containing the data.
The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output',
and can optionally have a key 'input' (see Alpaca)."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
val_split_fraction: Optional[float] = None
"""The fraction of the dataset to use for the validation dataset. The rest is used for training.
Only applies if you passed in a single file to `json_path`."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self):
super().__init__()
if self.json_path.is_file() and self.val_split_fraction is None:
self.val_split_fraction = 0.05
warnings.warn(
"The `json_path` points to a single file and `val_split_fraction` was not set. "
"Defaulting to `val_split_fraction=0.05`. Set `val_split_fraction` explicitly "
"to use a different split percentage.",
UserWarning,
stacklevel=2,
)
if self.json_path.is_dir() and self.val_split_fraction is not None:
raise ValueError(
"If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and"
f" hence `val_split_fraction` should not be set. Got `{self.val_split_fraction=}`."
)
if not self.json_path.exists():
raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files,"
f" but '{self.json_path!s}' does not exist."
)
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def setup(self, stage: str = "") -> None:
train_data, test_data = self.get_splits()
self.train_dataset = SFTDataset(
data=train_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
self.test_dataset = SFTDataset(
data=test_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def get_splits(self) -> Tuple:
# A single file (gets split into train and test)
if self.json_path.is_file():
data = load_split(self.json_path)
# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.val_split_fraction, self.val_split_fraction],
generator=torch.Generator().manual_seed(self.seed),
)
return train_data, test_data
# A directory containing train.json and val.json
if (train_file := self.find_split("train")) and (val_file := self.find_split("val")):
train_data = load_split(train_file)
test_data = load_split(val_file)
return train_data, test_data
raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files."
)
def find_split(self, split_name: str) -> Optional[Path]:
for suffix in (".json", ".jsonl"):
if (file := self.json_path / f"{split_name}{suffix}").is_file():
return file
return None
def load_split(json_path: Path) -> Any:
if json_path.suffix == ".json":
with open(json_path, encoding="utf-8") as file:
return json.load(file)
if json_path.suffix == ".jsonl":
with open(json_path, encoding="utf-8") as file:
return [json.loads(line) for line in file]
else:
raise ValueError(f"Unsupported file format: {json_path.suffix}. Expected `.json` or `.jsonl`.")
================================================
FILE: litgpt/data/lima.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
import os
from dataclasses import dataclass, field
from typing import List, Optional, Union
import torch
from torch.utils.data import DataLoader, random_split
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
@dataclass
class LIMA(DataModule):
"""LIMA data module for supervised finetuning."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
val_split_fraction: float = 0.1
"""The fraction of the dataset to use for the validation dataset. The rest is used for training."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
include_multiturn_conversations: bool = False
"""Whether to include multi-turn conversations in the dataset."""
repo_id: str = "GAIR/lima"
"""The Hugging Face dataset repository ID from where to download the data."""
access_token: Optional[str] = field(repr=False, default=os.getenv("HF_TOKEN"))
"""The Hugging Face API token to use for authentication. Can also be set through the
`HF_TOKEN` environment variable."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self):
super().__init__()
if self.access_token is None:
raise ValueError(
"LIMA requires authentication, please set the `HF_TOKEN=your_token` environment"
" variable or pass --access_token=your_token. You can find your token by visiting"
" https://huggingface.co/settings/tokens"
)
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def prepare_data(self) -> None:
from datasets import load_dataset
load_dataset(self.repo_id, token=self.access_token)
def setup(self, stage: str = "") -> None:
from datasets import load_dataset
dataset = load_dataset(self.repo_id, token=self.access_token)
data = format_dataset(dataset["train"], self.include_multiturn_conversations)
# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.val_split_fraction, self.val_split_fraction],
generator=torch.Generator().manual_seed(self.seed),
)
train_data, test_data = list(train_data), list(test_data)
self.train_dataset = SFTDataset(
data=train_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
self.test_dataset = SFTDataset(
data=test_data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def format_dataset(dataset_partition: dict, include_multi_turn_conversations: bool) -> List[dict]:
formatted_ds = []
for entry in dataset_partition:
convo = entry["conversations"]
if include_multi_turn_conversations:
for i in range(0, len(convo) - 1, 2):
formatted_ds.append({"instruction": convo[i], "input": "", "output": convo[i + 1]})
else:
formatted_ds.append({"instruction": convo[0], "input": "", "output": convo[1]})
return formatted_ds
================================================
FILE: litgpt/data/lit_data.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Union
from torch.utils.data import DataLoader
from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer
@dataclass
class LitData(DataModule):
"""Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks)."""
data_path: Union[str, Path] = Path("data/")
"""The path to the data directory containing the preprocessed chunks for the streaming dataset
The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders
for training- and validation splits."""
split_names: Optional[Tuple[str, str]] = None
"""Optional tuple for names of subfolders for training and validation under ``data_path``. If not provided,
all data under data_path will be used for training, and the validation dataloader will be identical to the
train dataloader."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""
batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)
def __post_init__(self) -> None:
super().__init__()
if self.split_names is not None and len(self.split_names) != 2:
raise ValueError("If provided `split_names` must be a tuple of two strings, for example: ('train', 'val').")
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
def train_dataloader(self) -> DataLoader:
input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path)
return self._dataloader(input_dir=input_dir, train=True)
def val_dataloader(self) -> DataLoader:
input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path)
return self._dataloader(input_dir=input_dir, train=False)
def _dataloader(self, input_dir: str, train: bool):
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
dataset = StreamingDataset(
input_dir=input_dir,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=train,
seed=self.seed,
)
dataloader = StreamingDataLoader(
dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return dataloader
================================================
FILE: litgpt/data/longform.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
import torch
from torch.utils.data import DataLoader
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.data.alpaca import download_if_missing
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
_URL = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset"
@dataclass
class LongForm(DataModule):
"""LongForm data module for supervised finetuning."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
prompt_style: Union[str, PromptStyle] = "longform"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -100
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/longform")
"""The directory in which the downloaded dataset gets saved."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
super().__init__()
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = -1 if max_seq_length is None else max_seq_length
def prepare_data(self) -> None:
self.download_dir.mkdir(parents=True, exist_ok=True)
download_if_missing(self.download_dir / "train.json", f"{_URL}/train.json")
download_if_missing(self.download_dir / "val.json", f"{_URL}/val.json")
def train_dataloader(self):
return self._dataloader("train")
def val_dataloader(self):
return self._dataloader("val")
def _dataloader(self, split: str) -> DataLoader:
with open(self.download_dir / f"{split}.json", encoding="utf-8") as file:
data = json.load(file)
dataset = SFTDataset(
data=data,
tokenizer=self.tokenizer,
prompt_style=self.prompt_style,
max_seq_length=self.max_seq_length,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
transform=_transform,
)
return DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=(split == "train"),
generator=torch.Generator().manual_seed(self.seed),
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
)
def _transform(item: dict) -> dict:
item["instruction"] = item.pop("input")
return item
================================================
FILE: litgpt/data/microllama.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass
from pathlib import Path
from typing import Union
from litgpt.data.tinyllama import TinyLlama
@dataclass
class MicroLlama(TinyLlama):
"""The MicroLlama data module is composed of only SlimPajama data."""
def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8):
super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)
================================================
FILE: litgpt/data/openwebtext.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union
from torch.utils.data import DataLoader
from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer
@dataclass
class OpenWebText(DataModule):
"""The OpenWebText data module for pretraining."""
data_path: Union[str, Path] = Path("data/openwebtext")
"""The path to the data directory, containing two folders 'train' and 'val'
which are the output of the preprocessing step. The path can also be a remote path (e.g., s3://)."""
val_split_fraction: float = 0.0005
"""The fraction of data that should be put aside for validation."""
seed: int = 42
"""The seed to use for shuffling the training data."""
num_workers: int = 8
"""The number of workers to use for the dataloaders."""
tokenizer: Optional[Tokenizer] = field(default=None, repr=False, init=False)
batch_size: int = field(default=1, repr=False, init=False)
seq_length: int = field(default=2048, repr=False, init=False)
def __post_init__(self) -> None:
super().__init__()
# Could be a remote path (s3://) or a local path
self.data_path_train = str(self.data_path).rstrip("/") + "/train"
self.data_path_val = str(self.data_path).rstrip("/") + "/val"
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = 2048
) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
def prepare_data(self) -> None:
from datasets import Dataset, load_dataset
from litdata import optimize
if str(self.data_path).startswith("s3://"):
print(f"The OpenWebText data path points to an S3 location: {self.data_path}. Skipping preprocessing.")
return
if Path(self.data_path_train).is_dir() and Path(self.data_path_val).is_dir():
print(f"Found OpenWebText train and val dir: {self.data_path}. Skipping preprocessing.")
return
dataset = load_dataset("openwebtext", num_proc=(os.cpu_count() // 2), trust_remote_code=True)
# Split the data in training and validation
split_dataset = dataset["train"].train_test_split(
test_size=self.val_split_fraction, seed=self.seed, shuffle=True
)
split_dataset["val"] = split_dataset.pop("test") # rename the test split to val
def tokenize(data: Dataset, index: int):
yield self.tokenizer.encode(data[index]["text"], eos=True)
optimize(
fn=partial(tokenize, split_dataset["train"]),
inputs=list(range(len(split_dataset["train"]))),
output_dir=self.data_path_train,
num_workers=min(64, os.cpu_count() - 1),
chunk_bytes="200MB",
)
optimize(
fn=partial(tokenize, split_dataset["val"]),
inputs=list(range(len(split_dataset["val"]))),
output_dir=self.data_path_val,
num_workers=min(8, os.cpu_count() - 1),
chunk_bytes="200MB",
)
def train_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
train_dataset = StreamingDataset(
input_dir=self.data_path_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
)
train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader
def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
val_dataset = StreamingDataset(
input_dir=self.data_path_val,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
)
val_dataloader = StreamingDataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
================================================
FILE: litgpt/data/prepare_slimpajama.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
import os
import time
from pathlib import Path
from litgpt.data.prepare_starcoder import DataChunkRecipe
from litgpt.tokenizer import Tokenizer
from litgpt.utils import CLI, extend_checkpoint_dir
class SlimPajamaDataRecipe(DataChunkRecipe):
is_generator = True
def __init__(self, tokenizer: Tokenizer, chunk_size: int):
super().__init__(chunk_size)
self.tokenizer = tokenizer
def prepare_structure(self, input_dir):
files = Path(input_dir).rglob("*.zst")
return [str(file) for file in files]
def prepare_item(self, filepath):
import zstandard as zstd
with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
for row in f:
text = json.loads(row)["text"]
if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub":
continue # exclude the GitHub data since it overlaps with starcoder
text_ids = self.tokenizer.encode(string=text, bos=False, eos=True)
yield text_ids
def prepare(
input_dir: Path = Path("data/SlimPajama-627B/train"),
output_dir: Path = Path("data/slimpajama/train"),
tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"),
chunk_size: int = (2049 * 16384),
fast_dev_run: bool = False,
) -> None:
from litdata.processing.data_processor import DataProcessor
tokenizer_path = extend_checkpoint_dir(tokenizer_path)
tokenizer = Tokenizer(tokenizer_path)
data_recipe = SlimPajamaDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)
data_processor = DataProcessor(
input_dir=str(input_dir),
output_dir=str(output_dir),
fast_dev_run=fast_dev_run,
num_workers=os.cpu_count(),
num_downloaders=1,
)
start_time = time.time()
data_processor.run(data_recipe)
elapsed_time = time.time() - start_time
print(f"Time taken: {elapsed_time:.2f} seconds")
if __name__ == "__main__":
CLI(prepare)
================================================
FILE: litgpt/data/prepare_starcoder.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import time
import traceback
from pathlib import Path
from litgpt.constants import _LITDATA_AVAILABLE
from litgpt.tokenizer import Tokenizer
from litgpt.utils import CLI, extend_checkpoint_dir
if _LITDATA_AVAILABLE:
from litdata.processing.data_processor import DataChunkRecipe
else:
DataChunkRecipe = object
class StarcoderDataRecipe(DataChunkRecipe):
is_generator = True
def __init__(self, tokenizer: Tokenizer, chunk_size: int):
super().__init__(chunk_size)
self.tokenizer = tokenizer
def prepare_structure(self, input_dir):
files = Path(input_dir).rglob("*.parquet")
return [str(file) for file in files]
def prepare_item(self, item_metadata):
import pyarrow.parquet as pq
filepath = item_metadata
start = time.time()
try:
parquet_file = pq.ParquetFile(filepath)
# reduce RAM usage
for batch in parquet_file.iter_batches(batch_size=8192, columns=["content"]):
for text in batch.to_pandas()["content"]:
yield self.tokenizer.encode(text, bos=False, eos=True)
except Exception:
print(traceback.format_exc())
print(f"Error reading {filepath}")
return
parquet_file.close()
end = time.time()
print(f"Took {end - start:.2f} seconds total", filepath)
def prepare(
input_dir: Path = Path("data/starcoderdata"),
output_dir: Path = Path("data/starcoder"),
tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"),
chunk_size: int = (2049 * 8192),
fast_dev_run: bool = False,
) -> None:
from litdata.processing.data_processor import DataProcessor
tokenizer_path = extend_checkpoint_dir(tokenizer_path)
tokenizer = Tokenizer(tokenizer_path)
data_recipe = StarcoderDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)
data_processor = DataProcessor(
input_dir=str(input_dir),
output_dir=str(output_dir),
fast_dev_run=fast_dev_run,
num_workers=os.cpu_count(),
num_downloaders=1,
)
start_time = time.time()
data_processor.run(data_recipe)
elapsed_time = time.time() - start_time
print(f"Time taken: {elapsed_time:.2f} seconds")
if __name__ == "__main__":
CLI(prepare)
================================================
FILE: litgpt/data/text_files.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
from torch.utils.data import DataLoader
from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer
@dataclass
class TextFiles(DataModule):
"""The TextFile data module used for pretraining.
Reads in text data from plaintext files contained in a data folder
and provides training and validation dataloaders that return batches of tokens.
Every sample is set to a fixed length.
"""
train_data_path: Path
"""The path to the data directory used for training that contains .txt files"""
val_data_path: Optional[Path] = None
"""The path to the data directory used for validation that
contains .txt files. Splits off data for validation from the
training set if None."""
seed: int = 42
"""The seed to use for shuffling the dataset."""
num_workers: int = 4
"""The number of workers to use for data loading."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
def __post_init__(self) -> None:
super().__init__()
self.out_path_train = self.train_data_path / "train"
if self.val_data_path is None:
self.out_path_val = self.train_data_path / "val"
else:
self.out_path_val = Path(self.val_data_path) / "val"
def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
def prepare_data(self) -> None:
from litdata import optimize
from litdata.streaming import TokensLoader
train_files = sorted(glob.glob(str(self.train_data_path / "*.txt")))
assert len(train_files) > 0, f"No .txt files found in train data {train_files}"
if self.val_data_path is not None:
self.val_data_path = Path(self.val_data_path)
val_files = sorted(glob.glob(str(self.val_data_path / "*.txt")))
assert len(val_files) > 0, f"No .txt files found in validation data {val_files}"
# train/test split. let's use only shard 0 for test split, rest train
else:
assert len(train_files) > 1, f"Expected at least two .txt files in {train_files}"
val_files, *train_files = train_files
val_files = [val_files]
# It's ok to use almost all CPUs here because this runs in a single process
num_workers = os.cpu_count() - 1
use_workers = min(num_workers, len(train_files))
if not Path(self.out_path_train).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=train_files,
output_dir=str(self.out_path_train),
num_workers=use_workers,
chunk_bytes="50MB",
item_loader=TokensLoader(block_size=self.max_seq_length),
)
else:
print(
f"\nWarning: Preprocessed training data found in {self.out_path_train}."
" For efficiency, reprocessing is skipped. If your text input has changed since"
" the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger"
f" reprocessing: `rm -rf {self.out_path_train}`\n"
)
use_workers = min(num_workers, len(val_files))
if not Path(self.out_path_val).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=val_files,
output_dir=str(self.out_path_val),
num_workers=use_workers,
chunk_bytes="50MB",
item_loader=TokensLoader(block_size=self.max_seq_length),
)
else:
print(
f"\nWarning: Preprocessed validation data found in {self.out_path_val}."
" For efficiency, reprocessing is skipped. If your text input has changed since"
" the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger"
f" reprocessing: `rm -rf {self.out_path_val}`\n"
)
def train_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
train_dataset = StreamingDataset(
input_dir=str(self.out_path_train),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
)
train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader
def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
val_dataset = StreamingDataset(
input_dir=str(self.out_path_val),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
)
val_dataloader = StreamingDataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
def tokenize(filename: str, tokenizer: Tokenizer):
with open(filename, encoding="utf-8") as file:
text = file.read()
text = text.strip()
yield tokenizer.encode(text, bos=True, eos=False)
def validate_tokenizer(tokenizer: Tokenizer) -> None:
if tokenizer is None:
raise ValueError(
"Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
"please provide a valid `--tokenizer_dir` path."
)
================================================
FILE: litgpt/data/tinyllama.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
from torch.utils.data import DataLoader
from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer
@dataclass
class TinyLlama(DataModule):
"""The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data.
Provides training and validation streaming dataloaders that return batches of tokens.
"""
data_path: Union[str, Path] = Path("data/")
"""The path to the data directory, containing two folders 'slimpajama' and 'starcoder'
which are the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md`
for instructions. The path can also be a remote path (e.g., s3://)."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""
use_starcoder: bool = True
"""Toggle for using Starcoder data."""
batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)
def __post_init__(self):
super().__init__()
# Could be a remote path (s3://) or a local path
self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"
self.required_paths = [self.slimpajama_train, self.slimpajama_val]
if self.use_starcoder:
self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
self.required_paths += [self.starcoder_train]
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
def prepare_data(self) -> None:
for path in self.required_paths:
if not path.startswith("s3://") and not Path(path).is_dir():
raise FileNotFoundError(
"The data path for TinyLlama is expected to be the directory containing these subdirectories:"
f" `slimpajama/train`, `slimpajama/val`, `starcoder`. The directory {path} does not exist."
" Set it via `--data.data_path=...`"
)
def train_dataloader(self) -> DataLoader:
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader
slim_train_data = StreamingDataset(
input_dir=self.slimpajama_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
)
train_data = slim_train_data
if self.use_starcoder:
train_datasets = [
slim_train_data,
StreamingDataset(
input_dir=self.starcoder_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
]
# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
train_data = CombinedStreamingDataset(
datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
)
train_dataloader = StreamingDataLoader(
train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader
def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
val_dataset = StreamingDataset(
input_dir=self.slimpajama_val,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
)
val_dataloader = StreamingDataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
================================================
FILE: litgpt/data/tinystories.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import json
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
from torch.utils.data import DataLoader
from tqdm import tqdm
from litgpt.data import DataModule
from litgpt.data.alpaca import download_if_missing
from litgpt.data.text_files import validate_tokenizer
from litgpt.tokenizer import Tokenizer
@dataclass
class TinyStories(DataModule):
"""The TinyStories data module: https://huggingface.co/datasets/roneneldan/TinyStories
Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length.
"""
data_path: Path = Path("data/tinystories")
"""The path to the data directory, containing two folders 'train' and 'val'
which are the output of the preprocessing step."""
seed: int = 42
"""The seed to use for shuffling the dataset."""
num_workers: int = 8
"""The number of workers to use for the dataloaders."""
tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
def __post_init__(self) -> None:
super().__init__()
self.data_path_train = self.data_path / "train"
self.data_path_val = self.data_path / "val"
def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
def prepare_data(self) -> None:
from litdata import TokensLoader, optimize
download(self.data_path)
files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json")))
assert len(files) > 0, f"No json files found in {files}"
assert len(files) > 1, f"Expected at least two json files in {files}"
# train/test split. let's use only shard 0 for test split, rest train
val_file, *train_files = files
num_workers = os.cpu_count() - 1
if not Path(self.data_path_train).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=train_files,
output_dir=str(self.data_path_train),
num_workers=num_workers,
chunk_bytes="200MB",
item_loader=TokensLoader(),
)
if not Path(self.data_path_val).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=[val_file],
output_dir=str(self.data_path_val),
num_workers=1, # there's only 1 file
chunk_bytes="200MB",
item_loader=TokensLoader(),
)
def train_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
train_dataset = StreamingDataset(
input_dir=str(self.data_path_train),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
)
train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader
def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader
val_dataset = StreamingDataset(
input_dir=str(self.data_path_val),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
)
val_dataloader = StreamingDataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
def tokenize(filename: str, tokenizer: Tokenizer):
with open(filename, encoding="utf-8") as f:
data = json.load(f)
global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"])
num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"])
local_rank = global_rank % num_workers
for example in tqdm(data, position=local_rank):
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS
yield tokens
_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
def download(data_dir: Path):
data_dir.mkdir(exist_ok=True, parents=True)
data_tar = data_dir / "TinyStories_all_data.tar.gz"
data_dir = data_dir / "TinyStories_all_data"
shard_filenames = sorted(glob.glob(str(data_dir / "*.json")))
if shard_filenames:
print(f"{data_dir} already exists, skipping unpacking...")
return
# download the TinyStories dataset, unless it's already downloaded
download_if_missing(data_tar, _URL, stream=True, mode="wb")
# unpack the tar.gz file into all the data shards (json files)
data_dir.mkdir(exist_ok=False)
tar_command = f"tar -xzf {data_tar} -C {data_dir}"
print(tar_command)
os.system(tar_command)
shard_filenames = sorted(glob.glob(str(data_dir / "*.json")))
print(f"Number of shards: {len(shard_filenames)}")
================================================
FILE: litgpt/deploy/__init__.py
================================================
================================================
FILE: litgpt/deploy/serve.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
import sys
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, Literal, Optional
import torch
from litgpt.api import LLM
from litgpt.constants import _JINJA2_AVAILABLE, _LITSERVE_AVAILABLE
from litgpt.utils import auto_download_checkpoint
if _LITSERVE_AVAILABLE:
from litserve import LitAPI, LitServer
from litserve.specs.openai import ChatCompletionRequest, OpenAISpec
else:
LitAPI, LitServer = object, object
class BaseLitAPI(LitAPI):
def __init__(
self,
checkpoint_dir: Path,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50,
devices: int = 1,
api_path: Optional[str] = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
) -> None:
if not _LITSERVE_AVAILABLE:
raise ImportError(str(_LITSERVE_AVAILABLE))
super().__init__(api_path=api_path)
self.checkpoint_dir = checkpoint_dir
self.quantize = quantize
self.precision = precision
self.temperature = temperature
self.top_k = top_k
self.max_new_tokens = max_new_tokens
self.top_p = top_p
self.devices = devices
self.generate_strategy = generate_strategy
def setup(self, device: str) -> None:
if ":" in device:
accelerator, device = device.split(":")
device = f"[{int(device)}]"
else:
accelerator = device
device = 1
print("Initializing model...", file=sys.stderr)
self.llm = LLM.load(model=self.checkpoint_dir, distribute=None)
self.llm.distribute(
devices=self.devices,
accelerator=accelerator,
quantize=self.quantize,
precision=self.precision,
generate_strategy=self.generate_strategy
or ("sequential" if self.devices is not None and self.devices > 1 else None),
)
print("Model successfully initialized.", file=sys.stderr)
def decode_request(self, request: Dict[str, Any]) -> Any:
prompt = str(request["prompt"])
return prompt
class SimpleLitAPI(BaseLitAPI):
def __init__(
self,
checkpoint_dir: Path,
quantize: Optional[str] = None,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50,
devices: int = 1,
api_path: Optional[str] = None,
generate_strategy: Optional[str] = None,
):
super().__init__(
checkpoint_dir,
quantize,
precision,
temperature,
top_k,
top_p,
max_new_tokens,
devices,
api_path=api_path,
generate_strategy=generate_strategy,
)
def setup(self, device: str):
super().setup(device)
def predict(self, inputs: str) -> Any:
output = self.llm.generate(
inputs,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
max_new_tokens=self.max_new_tokens,
)
return output
def encode_response(self, output: str) -> Dict[str, Any]:
# Convert the model output to a response payload.
return {"output": output}
class StreamLitAPI(BaseLitAPI):
def __init__(
self,
checkpoint_dir: Path,
quantize: Optional[str] = None,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50,
devices: int = 1,
api_path: Optional[str] = None,
generate_strategy: Optional[str] = None,
):
super().__init__(
checkpoint_dir,
quantize,
precision,
temperature,
top_k,
top_p,
max_new_tokens,
devices,
api_path=api_path,
generate_strategy=generate_strategy,
)
def setup(self, device: str):
super().setup(device)
def predict(self, inputs: torch.Tensor) -> Any:
yield from self.llm.generate(
inputs,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
max_new_tokens=self.max_new_tokens,
stream=True,
)
def encode_response(self, output):
for out in output:
yield {"output": out}
class OpenAISpecLitAPI(BaseLitAPI):
def __init__(
self,
checkpoint_dir: Path,
quantize: Optional[str] = None,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50,
devices: int = 1,
api_path: Optional[str] = None,
generate_strategy: Optional[str] = None,
):
super().__init__(
checkpoint_dir,
quantize,
precision,
temperature,
top_k,
top_p,
max_new_tokens,
devices,
api_path=api_path,
generate_strategy=generate_strategy,
)
def setup(self, device: str):
super().setup(device)
if not _JINJA2_AVAILABLE:
raise ImportError(str(_JINJA2_AVAILABLE))
from jinja2 import Template
config_path = self.checkpoint_dir / "tokenizer_config.json"
if not config_path.is_file():
raise FileNotFoundError(f"Tokenizer config file not found at {config_path}")
with open(config_path, encoding="utf-8") as fp:
config = json.load(fp)
chat_template = config.get("chat_template", None)
if chat_template is None:
print("The tokenizer config does not contain chat_template, falling back to a default.")
chat_template = "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}Assistant: "
self.chat_template = chat_template
self.template = Template(self.chat_template)
def decode_request(self, request: "ChatCompletionRequest") -> Any:
# Apply chat template to request messages
return self.template.render(messages=request.messages)
def predict(self, inputs: str, context: dict) -> Any:
# Extract parameters from context with fallback to instance attributes
temperature = context.get("temperature") or self.temperature
top_p = context.get("top_p", self.top_p) or self.top_p
max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens
# Run the model on the input and return the output.
yield from self.llm.generate(
inputs,
temperature=temperature,
top_k=self.top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
stream=True,
)
def run_server(
checkpoint_dir: Path,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1.0,
max_new_tokens: int = 50,
devices: int = 1,
accelerator: str = "auto",
port: int = 8000,
stream: bool = False,
openai_spec: bool = False,
access_token: Optional[str] = None,
api_path: Optional[str] = "/predict",
timeout: int = 30,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
) -> None:
"""Serve a LitGPT model using LitServe.
Evaluate a model with the LM Evaluation Harness.
Arguments:
checkpoint_dir: The checkpoint directory to load the model from.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Optional precision setting to instantiate the model weights in. By default, this will
automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
temperature: Temperature setting for the text generation. Value above 1 increase randomness.
Values below 1 decrease randomness.
top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
generated text but can also lead to more incoherent texts.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
max_new_tokens: The number of generation steps to take.
devices: How many devices/GPUs to use.
accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
port: The network port number on which the model is configured to be served.
stream: Whether to stream the responses.
openai_spec: Whether to use the OpenAISpec and enable OpenAI-compatible API endpoints. When True, the server will provide
`/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,
making it easy to integrate with existing applications that use the OpenAI API.
access_token: Optional API token to access models with restrictions.
api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
timeout: Request timeout in seconds. Defaults to 30.
generate_strategy: The generation strategy to use. The "sequential" strategy (default for devices > 1)
allows running models that wouldn't fit in a single card by partitioning the transformer blocks across
all devices and running them sequentially. "tensor_parallel" shards the model using tensor parallelism.
If None (default for devices = 1), the model is not distributed.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
server = LitServer(
api_class(
checkpoint_dir=checkpoint_dir,
quantize=quantize,
precision=precision,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_new_tokens=max_new_tokens,
devices=devices,
api_path=api_path,
generate_strategy=generate_strategy,
),
spec=OpenAISpec() if openai_spec else None,
accelerator=accelerator,
devices=1,
stream=stream,
timeout=timeout,
)
server.run(port=port, generate_client_file=False)
================================================
FILE: litgpt/eval/evaluate.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
import os
from pathlib import Path
from pprint import pprint
from typing import Optional, Union
import torch
from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint
from litgpt.utils import auto_download_checkpoint, copy_config_files
def prepare_results(results, save_filepath, print_results=True):
from lm_eval.utils import make_table
if print_results:
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
json_result = json.dumps(results, indent=2, ensure_ascii=False, default=str)
save_filepath.open("w", encoding="utf-8").write(json_result)
def convert_and_evaluate(
checkpoint_dir: Path,
tasks: Optional[str] = None,
out_dir: Optional[Path] = None,
force_conversion: bool = False,
num_fewshot: Optional[int] = None,
batch_size: Union[int, str] = 1,
device: Optional[str] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
limit: Optional[float] = None,
seed: int = 1234,
save_filepath: Optional[Path] = None,
access_token: Optional[str] = None,
) -> None:
"""Evaluate a model with the LM Evaluation Harness.
Arguments:
checkpoint_dir: Directory where the `lit_model.pth` and tokenizer files are located.
out_dir: Directory in which to save the converted checkpoints for evaluation.
Saves to `checkpoint_dir`/evaluate by default.
force_conversion: Set to `True` to reconvert the model and override
an existing model.pth from a previous evaluation call.
tasks: CSV of task names to evaluate. Example: "hellaswag,truthfulqa_mc2,mmlu"
num_fewshot: Number of examples in few-shot context.
batch_size: Batch size configuration as positive integer value (default: 1),
"auto", in the format 'auto:N', where 'auto:4' recomputes the batch size 4 times.
device: Device to use for evaluation, for example, "cuda" or "cuda:0".
limit: Limit on number of examples per task.
seed: Random seed.
save_filepath: The file where the results will be saved.
Saves to `out_dir/results.json` by default.
access_token: Optional API token to access models with restrictions.
"""
if tasks is None:
from lm_eval.tasks import TaskManager
taskm = TaskManager()
print("\n".join(taskm.task_index.keys()))
print(
"\n\nTo evaluate multiple tasks, you can chain the task names "
"listed above via a comma-separated list."
"\nFor example: `--tasks 'hellaswag,truthfulqa_mc2,mmlu'`. "
"\nTo search for a specific task, use `litgpt evaluate list | grep task_name`."
)
return
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
if not (isinstance(batch_size, int) and batch_size > 0) and not (
isinstance(batch_size, str) and batch_size.startswith("auto")
):
raise ValueError("batch_size must be a positive integer, 'auto', or in the format 'auto:N'.")
from lm_eval import evaluator
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if out_dir is None:
out_dir = checkpoint_dir / "evaluate"
else:
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
save_filepath = out_dir / Path("results.json") if save_filepath is None else Path(save_filepath)
model_path = out_dir / "pytorch_model.bin"
if not model_path.exists() or force_conversion:
copy_config_files(source_dir=checkpoint_dir, out_dir=out_dir)
convert_lit_checkpoint(checkpoint_dir=checkpoint_dir, output_dir=out_dir)
# Hack: LitGPT's conversion doesn't save a pickle file that is compatible to be loaded with
# `torch.load(..., weights_only=True)`, which is a requirement in HFLM.
# So we're `torch.load`-ing and `torch.save`-ing it again to work around this.
state_dict = torch.load(out_dir / "model.pth")
torch.save(state_dict, model_path)
os.remove(out_dir / "model.pth")
from lm_eval.models.huggingface import HFLM
model = HFLM(pretrained=str(out_dir.resolve()), device=device, batch_size=batch_size, dtype=dtype)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
results = evaluator.simple_evaluate(
model=model,
tasks=tasks.split(","),
num_fewshot=num_fewshot,
batch_size=batch_size,
device=device,
limit=limit,
random_seed=seed,
numpy_random_seed=seed,
torch_random_seed=seed,
)
prepare_results(results, save_filepath)
================================================
FILE: litgpt/finetune/__init__.py
================================================
================================================
FILE: litgpt/finetune/adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean
from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
auto_download_checkpoint,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
create_finetuning_performance_report,
get_default_supported_precision,
init_out_dir,
instantiate_bnb_optimizer,
instantiate_torch_optimizer,
load_checkpoint,
num_parameters,
parse_devices,
select_sft_generate_example,
)
def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/adapter"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: LoggerChoice = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the Adapter method.
Arguments:
checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
num_nodes: How many nodes the code is being run on.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
logger_name,
out_dir,
name=f"finetune-{config.name}",
log_interval=train.log_interval,
log_args=dataclasses.asdict(log),
)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
if devices * num_nodes > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
" when using the --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(
devices=devices,
num_nodes=num_nodes,
strategy=strategy,
precision=precision,
loggers=logger,
plugins=plugins,
)
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
config: Config,
data: DataModule,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
validate_args(train, eval)
tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
checkpoint_path = checkpoint_dir / "lit_model.pth"
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
mark_only_adapter_as_trainable(model)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
model = fabric.setup_module(model)
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(
device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)
train_time = time.perf_counter()
token_counts = fit(
fabric=fabric,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
devices=devices,
num_nodes=num_nodes,
checkpoint_dir=checkpoint_dir,
out_dir=out_dir,
train=train,
eval=eval,
data=data,
)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics)
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# Save the final Adapter checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth.adapter"
save_path.parent.mkdir(parents=True, exist_ok=True)
save_adapter_checkpoint(fabric, model, save_path)
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
def fit(
fabric: L.Fabric,
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: DataModule,
num_nodes: int = 1,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
val_loss = "n/a"
train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()
token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
}
while step_count < max_steps:
iter_num += 1
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
break
input_ids, targets = batch["input_ids"], batch["labels"]
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
step_count += 1
token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
token_counts["raw_tokens_plus_prompt_template"] += (
batch["token_counts"]["raw_plus_prompt_template"].sum().item()
)
token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()
total_lengths += input_ids.numel()
if iter_num % train.log_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
)
throughput.compute_and_log(step=iter_num)
metrics = {
"loss": loss,
"iter": iter_num,
"step": step_count,
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": token_counts["raw_tokens_plus_prompt_template"],
"total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=iter_num)
if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
val_loss_tensor = val_loss.detach().clone().to(fabric.device)
val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
fabric.all_reduce(val_loss_tensor, reduce_op="mean")
fabric.all_reduce(val_time_tensor, reduce_op="mean")
fabric.print(
f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
)
metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)
save_prompt_style(data.prompt_style, checkpoint_file.parent)
total_token_counts = {}
for key in token_counts:
total = fabric.all_reduce(token_counts[key], reduce_op="sum")
total_token_counts[key] = total.item()
return total_token_counts
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
model.train()
return val_loss
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
instruction = select_sft_generate_example(eval, data)
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
max_returned_tokens = len(encoded) + eval.max_new_tokens
if max_returned_tokens < model.max_seq_length:
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. "
f"The model's supported context size (post-training) is {model.config.block_size}."
)
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
fabric.print(f"Saving adapter weights to {str(file_path)!r}")
fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
issues = []
unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/finetune/adapter_v2.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean
from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
auto_download_checkpoint,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
create_finetuning_performance_report,
get_default_supported_precision,
init_out_dir,
instantiate_bnb_optimizer,
instantiate_torch_optimizer,
load_checkpoint,
load_checkpoint_update,
num_parameters,
parse_devices,
select_sft_generate_example,
)
def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/adapter-v2"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
resume: Optional[bool] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: LoggerChoice = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the Adapter V2 method.
Arguments:
checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
num_nodes: How many nodes the code is being run on.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
logger_name,
out_dir,
name=f"finetune-{config.name}",
log_interval=train.log_interval,
log_args=dataclasses.asdict(log),
)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
if devices * num_nodes > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
" when using the --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(
devices=devices,
num_nodes=num_nodes,
strategy=strategy,
precision=precision,
loggers=logger,
plugins=plugins,
)
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
config: Config,
data: DataModule,
resume: bool,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
validate_args(train, eval)
tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
checkpoint_path = checkpoint_dir / "lit_model.pth"
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
mark_only_adapter_v2_as_trainable(model)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
model = fabric.setup_module(model)
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(
device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
if resume:
# Finding last trace of adapter training
try:
resume = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=(lambda p: int(p.parent.name.split("-")[1])))
fabric.print(f"Resuming training from {resume}")
load_checkpoint_update(fabric, resume, model, checkpoint_path, strict=False)
resume = True
except ValueError:
fabric.print("No previous adapter found. Finetune from start.")
resume = False
load_checkpoint(fabric, model, checkpoint_path, strict=False)
else:
# strict=False because missing keys due to Adapter weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)
mark_only_adapter_v2_as_trainable(model)
train_time = time.perf_counter()
token_counts = fit(
fabric=fabric,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
devices=devices,
resume=resume,
num_nodes=num_nodes,
checkpoint_dir=checkpoint_dir,
out_dir=out_dir,
train=train,
eval=eval,
data=data,
)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics)
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# Save the final Adapter checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth.adapter_v2"
save_path.parent.mkdir(parents=True, exist_ok=True)
save_adapter_v2_checkpoint(fabric, model, save_path)
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
def fit(
fabric: L.Fabric,
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
resume: bool,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: DataModule,
num_nodes: int = 1,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
val_loss = "n/a"
train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()
token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
}
if not resume:
try:
iter_match = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
step_count = int(iter_match.parent.name.split("-")[1]) if iter_match else 0
except ValueError:
step_count = 0
fabric.print(f"Starting at step count {step_count}")
while step_count < max_steps and train_iterator.epoch < train.epochs:
iter_num += 1
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
break
input_ids, targets = batch["input_ids"], batch["labels"]
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
step_count += 1
token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
token_counts["raw_tokens_plus_prompt_template"] += (
batch["token_counts"]["raw_plus_prompt_template"].sum().item()
)
token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()
total_lengths += input_ids.numel()
if iter_num % train.log_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
)
throughput.compute_and_log(step=iter_num)
metrics = {
"loss": loss,
"iter": iter_num,
"step": step_count,
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": token_counts["raw_tokens_plus_prompt_template"],
"total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=iter_num)
if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
val_loss_tensor = val_loss.detach().clone().to(fabric.device)
val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
fabric.all_reduce(val_loss_tensor, reduce_op="mean")
fabric.all_reduce(val_time_tensor, reduce_op="mean")
fabric.print(
f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
)
metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter_v2"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)
save_prompt_style(data.prompt_style, checkpoint_file.parent)
total_token_counts = {}
for key in token_counts:
total = fabric.all_reduce(token_counts[key], reduce_op="sum")
total_token_counts[key] = total.item()
return total_token_counts
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
model.train()
return val_loss
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
instruction = select_sft_generate_example(eval, data)
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
max_returned_tokens = len(encoded) + eval.max_new_tokens
if max_returned_tokens < model.max_seq_length:
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. "
f"The model's supported context size (post-training) is {model.config.block_size}."
)
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
fabric.print(f"Saving adapter v2 weights to {str(file_path)!r}")
fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
issues = []
unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/finetune/full.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import time
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.model import GPT, Block, Config
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
auto_download_checkpoint,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
create_finetuning_performance_report,
find_resume_path,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
load_checkpoint,
num_parameters,
parse_devices,
select_sft_generate_example,
)
def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/full"),
precision: Optional[str] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
resume: Union[bool, Literal["auto"], Path] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
logger_name: LoggerChoice = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model.
Arguments:
checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
devices: How many devices/GPUs to use
num_nodes: How many nodes the code is being run on.
resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
logger_name,
out_dir,
name=f"finetune-{config.name}",
resume=bool(resume),
log_interval=train.log_interval,
log_args=dataclasses.asdict(log),
)
if devices * num_nodes > 1:
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy={Block},
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger)
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
def main(
fabric: L.Fabric,
devices: int,
resume: Union[bool, Literal["auto"], Path],
seed: int,
config: Config,
data: DataModule,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
validate_args(train, eval)
tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
checkpoint_path = checkpoint_dir / "lit_model.pth"
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
model = fabric.setup(model)
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0}
resume = find_resume_path(resume, out_dir)
if resume:
fabric.print(f"Resuming training from {resume}")
fabric.load(resume, state)
else:
load_checkpoint(fabric, state["model"], checkpoint_path)
train_time = time.perf_counter()
token_counts = fit(
fabric=fabric,
state=state,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
devices=devices,
num_nodes=num_nodes,
resume=resume,
checkpoint_dir=checkpoint_dir,
out_dir=out_dir,
train=train,
eval=eval,
data=data,
)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=state["iter_num"])
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# Save the final checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
fabric.save(save_path, {"model": state["model"]})
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
def fit(
fabric: L.Fabric,
state: Dict,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
resume: Union[bool, Literal["auto"], Path],
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: DataModule,
num_nodes: int = 1,
) -> None:
model = state["model"]
optimizer = state["optimizer"]
scheduler = state["scheduler"]
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
}
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
val_loss = "n/a"
initial_iter = state["iter_num"]
max_steps = train.max_steps or float("inf")
train_iterator = CycleIterator(train_dataloader)
# resume data loader state by fast-forwarding through all seen batches
if resume:
resume_t0 = time.perf_counter()
for resume_iter in range(initial_iter):
next(train_iterator)
if resume_iter % 1000 == 0:
fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}")
fabric.barrier()
fabric.print(
f"Resuming data loader finished. Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration"
f" {initial_iter}."
)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
fabric.barrier()
while state["step_count"] < max_steps:
state["iter_num"] += 1
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
break
input_ids, targets = batch["input_ids"], batch["labels"]
is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
# shift the targets such that output n predicts token n+1
loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
state["step_count"] += 1
token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
token_counts["raw_tokens_plus_prompt_template"] += (
batch["token_counts"]["raw_plus_prompt_template"].sum().item()
)
token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()
if state["iter_num"] % train.log_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
metrics = {
"loss": loss,
"iter": state["iter_num"],
"step": state["step_count"],
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": token_counts["raw_tokens_plus_prompt_template"],
"total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=state["iter_num"])
if not is_accumulating and state["step_count"] % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
val_loss_tensor = val_loss.detach().clone().to(fabric.device)
val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
fabric.all_reduce(val_loss_tensor, reduce_op="mean")
fabric.all_reduce(val_time_tensor, reduce_op="mean")
fabric.print(
f"iter {state['iter_num']}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
)
metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
fabric.log_dict(metrics, step=state["iter_num"])
fabric.barrier()
if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{state['step_count']:06d}" / "lit_model.pth"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file.parent)!r}")
fabric.save(checkpoint_file, state)
if fabric.global_rank == 0:
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)
save_prompt_style(data.prompt_style, checkpoint_file.parent)
total_token_counts = {}
for key in token_counts:
total = fabric.all_reduce(token_counts[key], reduce_op="sum")
total_token_counts[key] = total.item()
return total_token_counts
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
model.train()
return val_loss
@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
instruction = select_sft_generate_example(eval, data)
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
max_returned_tokens = len(encoded) + eval.max_new_tokens
if max_returned_tokens < model.max_seq_length:
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. "
f"The model's supported context size (post-training) is {model.config.block_size}."
)
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
issues = []
unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/finetune/lora.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
auto_download_checkpoint,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
create_finetuning_performance_report,
get_default_supported_precision,
init_out_dir,
instantiate_bnb_optimizer,
instantiate_torch_optimizer,
load_checkpoint,
num_parameters,
parse_devices,
select_sft_generate_example,
)
def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/lora"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_query: bool = True,
lora_key: bool = False,
lora_value: bool = True,
lora_projection: bool = False,
lora_mlp: bool = False,
lora_head: bool = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
max_time=None,
),
log: LogArgs = LogArgs(),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
optimizer: Union[str, Dict] = "AdamW",
logger_name: LoggerChoice = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the LoRA method.
Arguments:
checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
num_nodes: How many nodes the code is being run on.
lora_r: The LoRA rank.
lora_alpha: The LoRA alpha.
lora_dropout: The LoRA dropout value.
lora_query: Whether to apply LoRA to the query weights in attention.
lora_key: Whether to apply LoRA to the key weights in attention.
lora_value: Whether to apply LoRA to the value weights in attention.
lora_projection: Whether to apply LoRA to the output projection in the attention block.
lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.
lora_head: Whether to apply LoRA to output head in GPT.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(
checkpoint_dir / "model_config.yaml",
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_query=lora_query,
lora_key=lora_key,
lora_value=lora_value,
lora_projection=lora_projection,
lora_mlp=lora_mlp,
lora_head=lora_head,
)
precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
logger_name,
out_dir,
name=f"finetune-{config.name}",
log_interval=train.log_interval,
log_args=dataclasses.asdict(log),
)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
if devices * num_nodes > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
" when using the --quantize flag."
)
strategy = ModelParallelStrategy(
parallelize_fn=parallelize_fn,
data_parallel_size=devices * num_nodes,
tensor_parallel_size=1,
)
else:
strategy = "auto"
fabric = L.Fabric(
devices=devices,
num_nodes=num_nodes,
strategy=strategy,
precision=precision,
loggers=logger,
plugins=plugins,
)
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch(
main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision
)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
config: Config,
data: DataModule,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
precision: Optional[str] = None,
) -> None:
validate_args(train, eval)
tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
checkpoint_path = checkpoint_dir / "lit_model.pth"
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
mark_only_lora_as_trainable(model)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
model = fabric.setup_module(model)
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(
device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
load_checkpoint(fabric, model, checkpoint_path, strict=False)
train_time = time.perf_counter()
token_counts = fit(
fabric=fabric,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
devices=devices,
num_nodes=num_nodes,
checkpoint_dir=checkpoint_dir,
out_dir=out_dir,
train=train,
eval=eval,
data=data,
)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics)
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# Save the final LoRA checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth.lora"
save_path.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, save_path)
fabric.barrier()
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
merge_lora(
checkpoint_dir=save_path.parent,
pretrained_checkpoint_dir=checkpoint_dir,
precision=precision,
)
fabric.barrier()
def fit(
fabric: L.Fabric,
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: DataModule,
num_nodes: int = 1,
) -> dict:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
val_loss = "n/a"
train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()
max_time = train.max_time or float("inf")
token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
}
while step_count < max_steps:
iter_num += 1
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
generate_example(fabric, model, tokenizer, eval, data)
fabric.print(f"Number of epochs {train.epochs} reached, stopping training...")
break
if iter_t0 - total_t0 > max_time:
generate_example(fabric, model, tokenizer, eval, data)
fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...")
break
input_ids, targets = batch["input_ids"], batch["labels"]
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
step_count += 1
token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
token_counts["raw_tokens_plus_prompt_template"] += (
batch["token_counts"]["raw_plus_prompt_template"].sum().item()
)
token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()
total_lengths += input_ids.numel()
if iter_num % train.log_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
)
throughput.compute_and_log(step=iter_num)
metrics = {
"loss": loss,
"iter": iter_num,
"step": step_count,
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": token_counts["raw_tokens_plus_prompt_template"],
"total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=iter_num)
if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
val_loss_tensor = val_loss.detach().clone().to(fabric.device)
val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
fabric.all_reduce(val_loss_tensor, reduce_op="mean")
fabric.all_reduce(val_time_tensor, reduce_op="mean")
fabric.print(
f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
)
metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)
save_prompt_style(data.prompt_style, checkpoint_file.parent)
total_token_counts = {}
for key in token_counts:
total = fabric.all_reduce(token_counts[key], reduce_op="sum")
total_token_counts[key] = total.item()
return total_token_counts
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
model.train()
return val_loss
@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
instruction = select_sft_generate_example(eval, data)
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
max_returned_tokens = len(encoded) + eval.max_new_tokens
if max_returned_tokens < model.max_seq_length:
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. "
f"The model's supported context size (post-training) is {model.config.block_size}."
)
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
def parallelize_fn(model, device_mesh, activation_checkpointing=True):
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper
if activation_checkpointing:
model.transformer.h = torch.nn.ModuleList(
[checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]
)
dp_mesh = device_mesh["data_parallel"]
for m in reversed(list(model.modules())):
if (
(isinstance(m, torch.nn.Linear) and m.weight.requires_grad)
or isinstance(m, CheckpointWrapper)
or isinstance(m, Block)
):
fully_shard(m, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
return model
def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
cpu_state_dict = {}
sharded_sd = model.state_dict()
for param_name, param in sharded_sd.items():
if "lora_" not in param_name:
continue
if param.is_cpu:
param = param.to(fabric.device)
if hasattr(param, "_local_tensor"):
param = param.full_tensor()
if fabric.is_global_zero:
cpu_state_dict[param_name] = param.cpu()
fabric.barrier()
if fabric.is_global_zero:
torch.save({"model": cpu_state_dict}, file_path)
def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
issues = []
unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/finetune/lora_legacy.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics import RunningMean
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
auto_download_checkpoint,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
create_finetuning_performance_report,
get_default_supported_precision,
init_out_dir,
instantiate_bnb_optimizer,
instantiate_torch_optimizer,
load_checkpoint,
num_parameters,
parse_devices,
select_sft_generate_example,
)
def setup(
checkpoint_dir: Path,
out_dir: Path = Path("out/finetune/lora"),
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
num_nodes: int = 1,
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_query: bool = True,
lora_key: bool = False,
lora_value: bool = True,
lora_projection: bool = False,
lora_mlp: bool = False,
lora_head: bool = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
lr_warmup_steps=100,
epochs=5,
max_seq_length=None,
),
log: LogArgs = LogArgs(),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
optimizer: Union[str, Dict] = "AdamW",
logger_name: LoggerChoice = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the LoRA method.
Arguments:
checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
num_nodes: How many nodes the code is being run on.
lora_r: The LoRA rank.
lora_alpha: The LoRA alpha.
lora_dropout: The LoRA dropout value.
lora_query: Whether to apply LoRA to the query weights in attention.
lora_key: Whether to apply LoRA to the key weights in attention.
lora_value: Whether to apply LoRA to the value weights in attention.
lora_projection: Whether to apply LoRA to the output projection in the attention block.
lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.
lora_head: Whether to apply LoRA to output head in GPT.
data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(
checkpoint_dir / "model_config.yaml",
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_query=lora_query,
lora_key=lora_key,
lora_value=lora_value,
lora_projection=lora_projection,
lora_mlp=lora_mlp,
lora_head=lora_head,
)
precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(
logger_name,
out_dir,
name=f"finetune-{config.name}",
log_interval=train.log_interval,
log_args=dataclasses.asdict(log),
)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
if devices * num_nodes > 1:
if quantize:
raise NotImplementedError(
"Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
" when using the --quantize flag."
)
strategy = FSDPStrategy(
auto_wrap_policy={torch.nn.Linear},
activation_checkpointing_policy={Block},
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(
devices=devices,
num_nodes=num_nodes,
strategy=strategy,
precision=precision,
loggers=logger,
plugins=plugins,
)
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
config: Config,
data: DataModule,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
validate_args(train, eval)
tokenizer = Tokenizer(checkpoint_dir)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)
steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)
lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf")))
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
checkpoint_path = checkpoint_dir / "lit_model.pth"
with fabric.init_module(empty_init=(fabric.world_size > 1)):
model = GPT(config)
mark_only_lora_as_trainable(model)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
model = fabric.setup_module(model)
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())
from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(
device=old_embedding.weight.device, dtype=old_embedding.weight.dtype
)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
# strict=False because missing keys due to LoRA weights not contained in state dict
load_checkpoint(fabric, model, checkpoint_path, strict=False)
train_time = time.perf_counter()
token_counts = fit(
fabric=fabric,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
devices=devices,
num_nodes=num_nodes,
checkpoint_dir=checkpoint_dir,
out_dir=out_dir,
train=train,
eval=eval,
data=data,
)
training_time = time.perf_counter() - train_time
output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)
fabric.print(output)
# Final evaluation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics)
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# Save the final LoRA checkpoint at the end of training
save_path = out_dir / "final" / "lit_model.pth.lora"
save_path.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, save_path)
if fabric.global_rank == 0:
# Copy checkpoint files from original checkpoint dir
copy_config_files(checkpoint_dir, save_path.parent)
save_hyperparameters(setup, save_path.parent)
save_prompt_style(data.prompt_style, save_path.parent)
merge_lora(checkpoint_dir=save_path.parent)
def fit(
fabric: L.Fabric,
model: GPT,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: DataModule,
num_nodes: int = 1,
) -> dict:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(
ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])
)
model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
fabric.print(
f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
f" {model.max_seq_length} and context length is {model.config.block_size}"
)
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False) # sanity check
val_loss = "n/a"
train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
total_lengths = 0
total_t0 = time.perf_counter()
token_counts = {
"raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
"raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
}
while step_count < max_steps:
iter_num += 1
iter_t0 = time.perf_counter()
batch = next(train_iterator)
if train_iterator.epoch >= train.epochs:
break
input_ids, targets = batch["input_ids"], batch["labels"]
is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
scheduler.step()
step_count += 1
token_counts["raw_tokens"] += batch["token_counts"]["raw"].sum().item()
token_counts["raw_tokens_plus_prompt_template"] += (
batch["token_counts"]["raw_plus_prompt_template"].sum().item()
)
token_counts["raw_tokens_plus_prompt_template_and_padding"] += input_ids.numel()
total_lengths += input_ids.numel()
if iter_num % train.log_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
)
throughput.compute_and_log(step=iter_num)
metrics = {
"loss": loss,
"iter": iter_num,
"step": step_count,
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": token_counts["raw_tokens_plus_prompt_template"],
"total_tokens": token_counts["raw_tokens_plus_prompt_template"] * fabric.world_size,
"learning_rate": scheduler.get_last_lr()[0],
}
if isinstance(val_loss, torch.Tensor):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=iter_num)
if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, eval)
generate_example(fabric, model, tokenizer, eval, data)
t1 = time.perf_counter() - t0
val_loss_tensor = val_loss.detach().clone().to(fabric.device)
val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
fabric.all_reduce(val_loss_tensor, reduce_op="mean")
fabric.all_reduce(val_time_tensor, reduce_op="mean")
fabric.print(
f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
)
metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
copy_config_files(checkpoint_dir, checkpoint_file.parent)
save_hyperparameters(setup, checkpoint_file.parent)
save_prompt_style(data.prompt_style, checkpoint_file.parent)
total_token_counts = {}
for key in token_counts:
total = fabric.all_reduce(token_counts[key], reduce_op="sum")
total_token_counts[key] = total.item()
return total_token_counts
# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True
) -> torch.Tensor:
if verbose:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
val_loss = losses.mean()
model.train()
return val_loss
@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
instruction = select_sft_generate_example(eval, data)
fabric.print(instruction)
prompt = data.prompt_style.apply(instruction)
encoded = tokenizer.encode(prompt, device=fabric.device)
model.eval()
max_returned_tokens = len(encoded) + eval.max_new_tokens
if max_returned_tokens < model.max_seq_length:
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
output = generate(
model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
)
model.clear_kv_cache()
model.train()
output = tokenizer.decode(output)
fabric.print(f"{output}\n")
else:
print(
f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "
f"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. "
f"The model's supported context size (post-training) is {model.config.block_size}."
)
def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
# linear warmup followed by cosine annealing
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
return train_dataloader, val_dataloader
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
lengths = [len(d["input_ids"]) for d in data]
longest_seq_length = max(lengths)
longest_seq_ix = lengths.index(longest_seq_length)
return longest_seq_length, longest_seq_ix
def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
issues = []
unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
required = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if not train.epochs and not train.max_steps:
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/generate/__init__.py
================================================
================================================
FILE: litgpt/generate/adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt import PromptStyle, Tokenizer
from litgpt.adapter import GPT, Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
lazy_load,
)
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
input: str = "",
sys_prompt: Optional[str] = None,
adapter_path: Path = Path("out/finetune/adapter/final/lit_model.pth.adapter"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
"""For models finetuned with `litgpt finetune_adapter`.
Generates a response based on a given instruction and an optional input. This script will only work with
checkpoints from the instruction-tuned adapter model. See ``litgpt.finetune.adapter``.
Args:
checkpoint_dir: The path to the checkpoint folder with pretrained model weights.
prompt: The prompt/instruction (Alpaca style).
input: Optional input (Alpaca style).
sys_prompt: Optional system prompt.
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
``litgpt.finetune.adapter``.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
fabric.launch()
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
adapter_checkpoint = lazy_load(adapter_path)
checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
model.load_state_dict(checkpoint)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
model = fabric.setup(model)
L.seed_everything(1234)
t0 = time.perf_counter()
y = generate(
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0
output = tokenizer.decode(y)
output = output.split("### Response:")[1].strip()
fabric.print(output)
tokens_generated = y.size(0) - prompt_length
fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/adapter_v2.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt import PromptStyle, Tokenizer
from litgpt.adapter_v2 import GPT, Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
lazy_load,
)
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
input: str = "",
sys_prompt: Optional[str] = None,
adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
"""For models finetuned with `litgpt finetune adapter_v2`.
Generates a response based on a given instruction and an optional input. This script will only work with
checkpoints from the instruction-tuned adapter v2 model. See ``litgpt.finetune.adapter_v2``.
Args:
checkpoint_dir: The path to the checkpoint folder with pretrained model weights.
prompt: The prompt/instruction (Alpaca style).
input: Optional input (Alpaca style).
sys_prompt: Optional system prompt.
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
``litgpt.finetune.adapter_v2``.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
fabric.launch()
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
adapter_checkpoint = lazy_load(adapter_path)
checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
model.load_state_dict(checkpoint)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
model = fabric.setup(model)
L.seed_everything(1234)
t0 = time.perf_counter()
y = generate(
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0
output = tokenizer.decode(y)
output = output.split("### Response:")[1].strip()
fabric.print(output)
tokens_generated = y.size(0) - prompt_length
fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/base.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt.config import Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.model import GPT
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
)
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
if torch._dynamo.is_compiling():
# Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
distribution = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / distribution, dim=-1, keepdim=True)
return torch.multinomial(probs, num_samples=1)
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Example:
# sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
# sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least 1 token always to prevent the case where no token is selected
# In this case the most probable one is always kept
sorted_indices_to_remove[-1:] = 0
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
return logits
def sample(
logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0
) -> torch.Tensor:
if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be in [0, 1], got {top_p}")
logits = logits[0, -1]
# optionally crop the logits to only the top k options
if top_k is not None:
v, i = torch.topk(logits, min(top_k, logits.size(-1)))
# do not use `torch.where` as in nanogpt because it will repeat top-k collisions
logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
# optionally scale the logits and sample from a probability distribution
if temperature > 0.0 or top_p > 0.0:
if temperature > 0.0:
logits = logits / temperature
# optionally crop the logits to smallest set of logits with a cumulative probability above top_p
if top_p < 1.0:
logits = sample_top_p(logits, top_p)
probs = torch.nn.functional.softmax(logits, dim=-1)
return multinomial_num_samples_1(probs)
return torch.argmax(logits, dim=-1, keepdim=True)
def next_token(
model: GPT,
input_pos: torch.Tensor,
x: torch.Tensor,
input_pos_maxp1: Optional[int] = None,
**sample_kwargs: Dict[str, Any],
) -> torch.Tensor:
logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
_next = sample(logits, **sample_kwargs).to(dtype=torch.int64)
return _next
def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor:
assert len(logits) == len(kwargs), "logits and kwargs must have the same length."
return torch.stack(
[sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0
)
def batched_next_token(
model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]]
) -> torch.Tensor:
# Where:
# input_pos is a 1d tensor of shape [seq_length...]
# x is context tokens to add to the kvcache.
# For prefill, x is a 2d tensor of shape [batch_size, prompt_length].
# For subsequent tokens, x is a 2d tensor of shape [batch_size, 1].
# kwargs is a list of dictionaries, each containing the keyword arguments for the sample function.
# If one dictionary is passed, it's repeated for each sample in the batch.
# In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length].
# That way, we can support prompts of different sizes.
# This means making the rope cache and kvcache forward() work with batches. Currently, they do not.
# This is relatively complicated, given the current implementation. It will require some rewriting.
# Relevant thread: https://discuss.pytorch.org/t/batched-index-select/9115
# We will also need the same with tensor.index_copy_(). These do not work for batches, and the replacement
# is somewhat nontrivial. Until then, we can only accept prompts that are all the same length.
# After this problem is resolved, there will be another problem. That being, continuous batched prefill.
# If you have any ideas on this, let me know. I don't think that padding input_pos is viable.
_kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0)
# Run the model on the batch.
logits_stack = model(x, input_pos)
# Unbind the logits stack into a list of logits.
logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0)
logits_list = [l.unsqueeze(0) for l in logits_list]
# Return the next token for each sample in the batch.
return batched_sample(logits_list, kwargs=_kwargs)
@torch.inference_mode()
def generate_fn(
model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool,
include_eos: bool,
) -> Iterator[torch.Tensor]:
"""
Generates tokens for a single prompt.
Args:
model: The model to use.
prompt: The tokenized prompt to generate from.
max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens.
temperature: The temp to pass to sample().
top_k: The top_k to pass to sample().
top_p: The top_p to pass to sample().
stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.
include_prompt: Whether to output the prompt tokens.
include_eos: Whether to output the stop tokens if generation stops early.
"""
prompt_size = prompt.size(0)
device = prompt.device
assert max_returned_tokens > prompt_size, (
f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}."
)
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
# Yield the prompt if include_prompt is True
if include_prompt:
yield prompt
stop_progress = [0] * len(stop_tokens)
yielded_idx = 0
# Generate output tokens.
# The first token generated is the prefill token.
# The input_pos for this token is the width of the entire prompt.
# For subsequent iterations, it's the index in the context for the token that we're generating.
tokens = []
token = prompt
prefill_token = True
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
# input_pos_maxp1 introduces data-dependent shapes and control flow.
# We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None
for current_idx in range(max_returned_tokens - prompt_size):
# Generate the token
token = next_token(
model,
input_pos,
token.view(1, -1),
input_pos_maxp1=input_pos_maxp1,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
tokens.append(token)
int_token = token.item()
# Check for stop sequences
# For each stop sequence, we keep a running total of how many are matched in stop_progress.
# If the current token matches the next token in the stop sequence, we increment the
# running total and hold off on yielding the token.
for i, seq in enumerate(stop_tokens):
if int_token == seq[stop_progress[i]]:
stop_progress[i] += 1
if stop_progress[i] == len(seq):
if include_eos:
yield from tokens[yielded_idx:]
return
else:
stop_progress[i] = 0
# Yield tokens that are not part of a stop sequence in progress.
# If there are no stop sequences, then that's all of them.
if stop_tokens:
safe_idx = len(tokens) - max(stop_progress)
else:
safe_idx = current_idx + 1 # include the token just generated
if yielded_idx < safe_idx:
y_tokens = tokens[yielded_idx:safe_idx]
yield from y_tokens
yielded_idx = safe_idx
# Update input_pos for the next iteration.
if prefill_token:
prefill_token = False
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)
if input_pos_maxp1 is not None:
input_pos_maxp1 += 1
# Yield any remaining tokens
if yielded_idx < len(tokens):
yield from tokens[yielded_idx:]
# TODO: Make include_eos work.
# TODO: Rewrite unbatched generate_fn to use batched_generate_fn.
@torch.inference_mode()
def batched_generate_fn(
model: GPT,
prompts: torch.Tensor,
max_returned_tokens: int,
*,
sample_args: Union[list[dict], dict],
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool,
include_eos: bool,
) -> Iterator[list[Union[torch.Tensor, None]]]:
"""
Generates tokens for a batch of prompts.
Args:
model: The model to use.
prompts: A 2D tensor of shape [batch_size, prompt_length].
max_returned_tokens: The maximum number of tokens to return, including the prompt tokens.
sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch.
stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.
include_prompt: Whether to output the prompt tokens.
include_eos: Whether to output the stop tokens if generation stops early.
Yields:
A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch.
"""
if prompts.ndim == 1:
prompts = prompts.unsqueeze(0)
assert prompts.ndim == 2, "Prompts must be a 2D tensor."
batch_size = prompts.size(0)
max_prompt_size = prompts.size(1)
device = prompts.device
if isinstance(sample_args, dict):
sample_args = [sample_args] * len(prompts)
else:
assert len(sample_args) == batch_size, "sample_args must have the length as the batch size."
# TODO: This check (and the one in generate_fn) is not sufficient. We do the proper checks in LLM.generate().
assert max_returned_tokens > max_prompt_size, (
f"Not enough space for {max_prompt_size} prompt tokens in a context length of {max_returned_tokens}."
)
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
# Yield the prompts if include_prompt is True
if include_prompt:
# TODO: Prompt length is padded, but they shouldn't all be the same length.
for i in range(max_prompt_size):
yield [prompt[i].view(-1) for prompt in prompts]
stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)]
stop_idxes = [-1] * batch_size
yielded_idx = 0
# Generate output tokens.
# The first token generated is the prefill token.
# The input_pos for this token is the width of the entire prompt.
# For subsequent iterations, it's the index in the context for the token that we're generating.
token_lists = [[] for _ in range(batch_size)]
tokens: torch.Tensor = prompts
prefill_token = True
input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64)
for current_idx in range(max_returned_tokens - max_prompt_size):
# Generate the next token for each prompt in the batch.
# This is of shape [batch_size, 1].
tokens = batched_next_token(model, input_pos, tokens, sample_args)
for i in range(batch_size):
token_lists[i].append(tokens[i])
int_tokens = [token.item() for token in tokens]
# Check for stop sequences
# For each stop sequence, we keep a running total of how many are matched in stop_progress.
# If the current token matches the next token in the stop sequence, we increment the
# running total and hold off on yielding the token.
for batch_idx, int_token in enumerate(int_tokens):
if stop_idxes[batch_idx] != -1:
continue
for seq_idx, seq in enumerate(stop_tokens):
seq_pos = stop_progresses[batch_idx][seq_idx]
if seq_pos >= len(seq):
continue
if int_token == seq[seq_pos]:
stop_progresses[batch_idx][seq_idx] += 1
if stop_progresses[batch_idx][seq_idx] == len(seq):
stop_idxes[batch_idx] = current_idx
else:
stop_progresses[batch_idx][seq_idx] = 0
# Yield tokens that are not part of a stop sequence in progress.
# If there are no stop sequences, then that's all of them.
if len(stop_tokens) != 0:
safe_idxes = [len(token_lists[i]) - max(stop_progresses[i]) for i in range(batch_size)]
else:
safe_idxes = [current_idx + 1] # include the token just generated
safe_idx = min(safe_idxes)
if yielded_idx < safe_idx:
for idx in range(yielded_idx, safe_idx):
y_tokens = [
token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None
for i in range(batch_size)
]
if all(y is None for y in y_tokens):
return
yield y_tokens
yielded_idx = safe_idx
# Update input_pos for the next iteration.
if prefill_token:
prefill_token = False
# TODO: Make the model support a batched input_pos of shape [batch_size, 1].
# The kvcache has been fixed, but the rope cache is still broken.
input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)
# Yield any remaining tokens
max_token_lists = max(len(l) for l in token_lists)
if yielded_idx < max_token_lists:
for idx in range(yielded_idx, max_token_lists):
y_tokens = [
token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)
]
if all(y is None for y in y_tokens):
return
yield y_tokens
return
@torch.inference_mode()
def generate(
model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id: Optional[int] = None,
include_prompt: bool = True,
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
prompt: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
eos_id: If specified, stop generating any more token once the token is triggered.
include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
"""
token_list = list(
generate_fn(
include_prompt=include_prompt,
include_eos=True,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=(([eos_id],) if eos_id is not None else ()),
)
)
return torch.cat(token_list) if not len(token_list) == 0 else torch.Tensor()
@torch.inference_mode()
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
*,
sys_prompt: Optional[str] = None,
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
"""Default generation option.
Generates text samples based on a pre-trained model and tokenizer.
Args:
checkpoint_dir: The checkpoint directory to load.
prompt: The prompt string to use for generating the samples.
sys_prompt: The system prompt to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
if compile:
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
global next_token
next_token = torch.compile(next_token, mode="reduce-overhead")
model = fabric.setup_module(model)
t0 = time.perf_counter()
load_checkpoint(fabric, model, checkpoint_path)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(
model,
encoded,
max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
fabric.print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/full.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt import GPT, Config, PromptStyle, Tokenizer
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
)
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
input: str = "",
sys_prompt: Optional[str] = None,
finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"),
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
max_new_tokens: int = 100,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
precision: Optional[str] = None,
) -> None:
"""For models finetuned with `litgpt finetune_full`.
Generates a response based on a given instruction and an optional input. This script will only work with
checkpoints from the instruction-tuned model. See ``litgpt.finetune.full``.
Args:
checkpoint_dir: The path to the checkpoint folder with pretrained model weights.
prompt: The prompt/instruction (Alpaca style).
input: Optional input (Alpaca style).
sys_prompt: Optional system prompt.
finetuned_path: Path to the checkpoint with trained weights, which are the output of
``litgpt.finetune.full``.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
precision: Indicates the Fabric precision setting to use.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
fabric.launch()
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = finetuned_path
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
model = fabric.setup(model)
t0 = time.perf_counter()
load_checkpoint(fabric, model, checkpoint_path)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
L.seed_everything(1234)
t0 = time.perf_counter()
y = generate(
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0
output = tokenizer.decode(y)
output = output.split("### Response:")[1].strip()
fabric.print(output)
tokens_generated = y.size(0) - prompt_length
fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/sequentially.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import itertools
import logging
import re
import sys
import time
import warnings
from collections import OrderedDict
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import List, Literal, Optional, Type
import lightning as L
import torch
from lightning.fabric.accelerators import CUDAAccelerator
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.utilities.init import _materialize_meta_tensors
from tqdm import tqdm
import litgpt.generate.base as generate_base
from litgpt.config import Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.model import GPT, Block, build_mask_cache
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
)
@torch.inference_mode()
def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int):
if model.config.n_layer < devices:
raise ValueError(
f"The number of layers in the model must be larger than the number of devices, but got"
f" n_layer={model.config.n_layer} and devices={devices}."
)
# Dictates where each block should be instantiated
mapping = layer_to_device(
model,
chunk_on=Block,
chunk_sizes=chunk_sizes(model.config.n_layer, devices),
)
num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)}
# materialize each block on the appropriate device
with tqdm(total=len(mapping), desc="Moving submodules") as pbar:
for path, target_index in mapping.items():
submodule = model.get_submodule(path)
target_device = torch.device(root.type, target_index)
pbar.set_description(f"Moving {path!r} to {target_device}")
pbar.update(1)
# submodules loaded by the checkpoint will be on CPU (if no quantization). move them
replace_device(submodule, replace=torch.device("cpu"), by=target_device)
# in case the checkpoint was partial, materialize leftover metas
_materialize_meta_tensors(submodule, target_device)
# and build the kv cache
submodule.attn.kv_cache = submodule.attn.build_kv_cache(
1, max_seq_length, model.rope_cache_length(), target_device
)
# rebuild odd ends
with root:
model.max_seq_length = max_seq_length
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers
model.mask_cache = build_mask_cache(max_seq_length)
# and everything that is not a block in the root
_materialize_meta_tensors(model, root)
replace_device(model, replace=torch.device("cpu"), by=root)
if devices > 1:
# install hooks to move layer inputs/output between devices
for layer_num, (path, target_index) in enumerate(mapping.items()):
submodule = model.get_submodule(path)
if layer_num >= num_layers_per_device[target_index]:
# we need to move the block input on the boundaries between devices
# and also on every non-root device because the RoPE and mask cache is shared
# TODO: the second case could be optimized and then we would only need this hook for
# `layer_num in [layers_per_rank * i - 1 for i in range(1, devices + 1)]`
target_device = torch.device(root.type, target_index)
submodule.register_forward_pre_hook(partial(move_block_input, target_device))
if layer_num == model.config.n_layer - 1:
submodule.register_forward_hook(partial(move_block_output, root))
return model
def chunk_sizes(num_units: int, devices: int) -> List[int]:
cs = num_units // devices
k = devices * (cs + 1) - num_units
return [cs] * k + [cs + 1] * (devices - k)
def layer_to_device(
module: torch.nn.Module,
chunk_on: Type[torch.nn.Module],
chunk_sizes: List[int],
) -> "OrderedDict[str, int]":
"""Create a mapping from layer (block) to device."""
# this assumes that the definition order is the same as the execution order
hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)]
if sum(chunk_sizes) != len(hits):
raise ValueError(f"Found {len(hits)} for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}")
_devices = [[d] * cs for d, cs in enumerate(chunk_sizes)]
devices = [d for lst in _devices for d in lst]
return OrderedDict(zip(hits, devices))
def move_block_input(device: torch.device, module: torch.nn.Module, ins):
"""``forward_pre_hook`` to move a Block's input before forward."""
# during inference, none of the inputs are None: x, cos, sin, mask, input_pos
return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)
def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor:
"""``forward_hook`` to move a Block's output after forward."""
return outs.to(device)
def replace_device(module: torch.nn.Module, replace: torch.device, by: torch.device) -> torch.nn.Module:
for name, submodule in module.named_modules():
tensors = dict(
itertools.chain(submodule.named_parameters(recurse=False), submodule.named_buffers(recurse=False))
)
if not tensors:
continue
devices = {t.device for t in tensors.values()}
if len(devices) != 1:
# since this is using `submodule.to`, different devices in the same submodule is a problem
path_to_device = {f"{name}.{p}": t.device for p, t in tensors.items()}
raise ValueError(f"Found multiple devices: {path_to_device}")
if devices.pop() == replace:
submodule.to(by)
return module
@torch.inference_mode()
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
*,
sys_prompt: Optional[str] = None,
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
"""Generation script that partitions layers across devices to be run sequentially.
Generates text samples based on a pre-trained model and tokenizer.
Args:
checkpoint_dir: The checkpoint directory to load.
prompt: The prompt string to use for generating the samples.
sys_prompt: The system prompt to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None:
if compile:
raise NotImplementedError # untested
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes").setLevel(logging.DEBUG)
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, accelerator="cuda", plugins=plugins)
total_devices = CUDAAccelerator.auto_device_count()
print(f"Using {total_devices} devices", file=sys.stderr)
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
# cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced
# which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert
# still, use init_tensor for the precision
with fabric.init_tensor(), torch.device("meta"):
model = GPT(config)
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
t0 = time.perf_counter()
state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
# TODO: this assumes that the model fits on CPU. Use lazy_load and make the materialization checkpoint aware
model.load_state_dict(state_dict, assign=True)
print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
model = fabric.setup_module(model, move_to_device=False)
t0 = time.perf_counter()
model = sequential(model, fabric.device, max_returned_tokens, total_devices)
print(f"Time to sequential-ize the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
if compile:
# TODO: raises an internal compile AssertionError caused by fabric.strategy.precision.forward_context
raise NotImplementedError
# silence developer warning on nightly builds
# https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/ir.py#L4166
pattern = re.compile(".*DeviceCopy in input program.*")
logging.getLogger("torch._inductor.utils").addFilter(lambda record: not pattern.search(record.getMessage()))
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
# cannot use cudagraphs because it doesn't support multiple device indices
# https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375
generate_base.next_token = torch.compile(generate_base.next_token)
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate_base.generate(
model=model,
prompt=encoded,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/speculative_decoding.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config
import torch.nn.functional as F
from lightning.fabric.plugins import BitsandbytesPrecision
from litgpt.config import Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.generate.base import multinomial_num_samples_1, next_token, sample_top_p
from litgpt.model import GPT
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
)
def sample(
logits: torch.Tensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
apply_softmax: bool = True,
) -> torch.Tensor:
if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be in [0, 1], got {top_p}")
logits = logits[0, -1]
# optionally crop the logits to only the top k options
if top_k is not None:
v, i = torch.topk(logits, min(top_k, logits.size(-1)))
# do not use `torch.where` as in nanogpt because it will repeat top-k collisions
fill_value = float("-inf") if apply_softmax else float(0)
logits = torch.full_like(logits, fill_value).scatter_(-1, i, v)
# optionally scale the logits and sample from a probability distribution
if temperature > 0.0 or top_p > 0.0:
if temperature > 0.0:
logits = logits / temperature
# optionally crop the logits to smallest set of logits with a cumulative probability above top_p
if top_p < 1.0:
logits = sample_top_p(logits, top_p)
probs = F.softmax(logits, dim=-1) if apply_softmax else logits
return multinomial_num_samples_1(probs), probs
return torch.argmax(logits, dim=-1, keepdim=True), F.softmax(logits, dim=-1)
def speculative_decoding(
draft_model: GPT,
target_model: GPT,
token: torch.Tensor,
input_pos: torch.Tensor,
input_pos_maxp1: int,
speculative_k: int,
**sample_kwargs: Dict[str, Any],
) -> torch.Tensor:
"""Performs speculative decoding using a draft and a target model.
This implements the speculative decoding algorithm from "Fast Inference from Transformers via Speculative Decoding"
(https://arxiv.org/pdf/2211.17192).
The core idea is to:
1. Use a faster draft model to predict multiple tokens ahead
2. Verify those predictions with the slower but more accurate target model
3. Accept tokens where the target model agrees with high probability
4. Reject and resample tokens where there is a disagreement
This allows leveraging a smaller/faster model to speed up generation while maintaining
the quality of the larger target model.
Args:
draft_model: Smaller/faster model used for initial token predictions
target_model: Larger/slower model used for verification
token: Current input token tensor of shape [1]
input_pos: Position index of the token tensor for KV-cache
input_pos_maxp1: Maximum position + 1 for managing KV-cache buffer
speculative_k: Number of tokens to speculatively generate at once
sample_kwargs: Additional sampling parameters (temperature, top_k, top_p)
Returns:
torch.Tensor: Generated tokens that were either accepted from draft model
or resampled from target model
"""
if speculative_k < 1:
raise ValueError(f"speculative_k must be >= 1, got {speculative_k}")
# Step 1: Generate candidate tokens using draft model
# The draft model autoregressively generates k tokens, keeping track of probabilities
draft_input_pos = input_pos.clone()
draft_input_pos_maxp1 = input_pos_maxp1
draft_tokens, draft_probs = [], []
draft_token = token
for idx in range(speculative_k):
logits = draft_model(
idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1
)
draft_token, draft_prob = sample(logits, **sample_kwargs)
draft_input_pos.add_(1)
draft_input_pos_maxp1 += 1
draft_tokens.append(draft_token)
draft_probs.append(draft_prob)
draft_tokens = torch.cat(draft_tokens)
# Step 2: Get target model predictions for comparison
# Feed both original token and draft tokens to get target probabilities
candidate_tokens = torch.cat((token, draft_tokens))
candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device)
candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k
target_logits = target_model(
idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1
)
# Step 3: Convert target logits to probabilities using same sampling params
target_probs = []
for target_logit in target_logits.split(1, dim=1):
_, target_prob = sample(target_logit, **sample_kwargs)
target_probs.append(target_prob)
# Step 4: Accept/reject draft tokens based on probability comparison
# Using rejection sampling: keep token if target_prob >= draft_prob.
# Otherwise reject with probability 1 - target_prob / draft_prob.
# If rejected, sample from an adjusted distribution: norm(max(0, target_prob_distribution - draft_prob_distribution) instead.
accepted_tokens = []
for idx in range(len(draft_tokens)):
draft_token = draft_tokens[idx].unsqueeze(0)
draft_prob = draft_probs[idx][draft_token]
target_prob = target_probs[idx][draft_token]
# Accept the draft token if the target model is "confident" in it
if target_prob >= draft_prob:
accepted_tokens.append(draft_token)
continue
# If not accepted, probabilistically reject it
discard_prob = 1 - target_prob / draft_prob
should_discard_token = torch.rand(1, device=discard_prob.device) <= discard_prob
if not should_discard_token:
accepted_tokens.append(draft_token)
continue
# On rejection: sample new token from adjusted distribution
# p'(x) = normalize(max(0, p_target(x) - p_draft(x)))
adjusted_distribution = target_probs[idx] - draft_probs[idx]
adjusted_distribution = torch.clamp(adjusted_distribution, 0.0)
adjusted_distribution = adjusted_distribution / adjusted_distribution.sum()
new_token, _ = sample(adjusted_distribution[None, None, ...], apply_softmax=False, **sample_kwargs)
return torch.cat((*accepted_tokens, new_token))
# If all draft tokens were accepted:
# 1. Update draft model's key-value cache
# 2. Sample one more token from target model
draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1)
new_token, _ = sample(target_logits, **sample_kwargs)
return torch.cat((*accepted_tokens, new_token))
@torch.inference_mode()
def generate(
draft_model: GPT,
target_model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool = True,
speculative_k: int,
) -> Iterator[torch.Tensor]:
"""Generates tokens using speculative decoding with a draft and a target model.
This function implements token generation using speculative decoding, where a faster draft model
makes initial token predictions that are verified by a slower but more accurate target model.
Args:
draft_model: Smaller/faster model used for initial token predictions
target_model: Larger/more accurate model used to verify draft predictions
prompt: Input tensor of token ids to generate from, shape [sequence_length]
max_returned_tokens: Maximum total tokens (prompt + generated) to return
temperature: Sampling temperature (higher = more random, lower = more deterministic)
top_k: If set, only sample from the top k most likely next tokens
top_p: If <1.0, only sample from tokens whose cumulative probability exceeds top_p
stop_tokens: List of token sequences that will stop generation if produced
include_prompt: Whether to include prompt tokens in the returned sequence
speculative_k: Number of tokens to speculatively generate at each step
Returns:
- tokens: Tensor of generated token ids
- acceptance_rate: Ratio of accepted draft model predictions
This implements an optimized decoding process:
1. Both models process the initial prompt
2. Draft model speculatively generates k tokens ahead
3. Target model verifies the draft predictions
4. Accepted tokens are kept, rejected ones trigger resampling
5. Process repeats until max tokens or stop sequence reached
"""
prompt_size = prompt.size(0)
device = prompt.device
assert max_returned_tokens > prompt_size, (
f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}."
)
if draft_model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(
f"max_seq_length {draft_model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
if target_model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(
f"max_seq_length {target_model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
# Step 1: Prefill draft and target models with the prompt.
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
# We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
input_pos_maxp1 = (
prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None
)
next_token(
draft_model,
input_pos,
prompt.view(1, -1),
input_pos_maxp1=input_pos_maxp1,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
token = next_token(
target_model,
input_pos,
prompt.view(1, -1),
input_pos_maxp1=input_pos_maxp1,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
# Update position trackers after prompt
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
input_pos_maxp1 += 1
# Step 2: Main generation loop.
tokens = []
total_generated, total_accepted = 0, 0 # Track acceptance statistics
while input_pos < max_returned_tokens - 1:
# Calculate speculative tokens to generate
_speculative_k = min(speculative_k, (max_returned_tokens - input_pos - 1).item())
# Get new tokens via speculative decoding
new_tokens = speculative_decoding(
draft_model=draft_model,
target_model=target_model,
token=token,
input_pos=input_pos,
input_pos_maxp1=input_pos_maxp1,
speculative_k=_speculative_k,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
# Update statistics
accepted_tokens_len = len(new_tokens)
total_generated += _speculative_k
total_accepted += accepted_tokens_len - 1 # accepted +1 sampled from a target model
# Process tokens and check for stop condition
should_break = False
for new_token in new_tokens:
if new_token in stop_tokens:
should_break = True
break
tokens.append(new_token)
if should_break:
break
# Update positions for next iteration
input_pos.add_(accepted_tokens_len)
input_pos_maxp1 += accepted_tokens_len
token = new_tokens[-1].unsqueeze(0)
# Finalize generated sequence
tokens = torch.stack(tokens)
if include_prompt:
tokens = torch.cat([prompt, tokens])
acceptance_rate = total_accepted / total_generated if total_generated > 0 else 0.0
return tokens, acceptance_rate
def setup_model(config: Config, max_returned_tokens: int, fabric: L.Fabric) -> GPT:
"""Helper function to setup a model with common configuration."""
with fabric.init_module(empty_init=True):
model = GPT(config)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
return fabric.setup_module(model)
def load_model(checkpoint_dir: Path, fabric: L.Fabric) -> Tuple[Config, Path]:
"""Helper function to validate and load model configuration."""
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
return config, checkpoint_path
@torch.inference_mode()
def main(
draft_model_checkpoint_dir: Path,
target_model_checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
*,
sys_prompt: Optional[str] = None,
num_samples: int = 1,
max_new_tokens: int = 50,
speculative_k: int = 3,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
"""Default generation option.
Generates text samples based on pre-trained models and a tokenizer.
Args:
draft_model: Smaller/faster model used for initial token predictions
target_model: Larger/more accurate model used to verify draft predictions
prompt: The prompt string to use for generating the samples.
sys_prompt: The system prompt to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
speculative_k: Number of tokens to speculatively generate at each step
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
"""
draft_model_checkpoint_dir = extend_checkpoint_dir(draft_model_checkpoint_dir)
target_model_checkpoint_dir = extend_checkpoint_dir(target_model_checkpoint_dir)
pprint(locals())
# Setup Fabric
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None and quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
# Load model configs and checkpoints
draft_config, draft_checkpoint_path = load_model(draft_model_checkpoint_dir, fabric)
target_config, target_checkpoint_path = load_model(target_model_checkpoint_dir, fabric)
# Setup tokenizer and validate
draft_tokenizer = Tokenizer(draft_model_checkpoint_dir)
target_tokenizer = Tokenizer(target_model_checkpoint_dir)
if draft_tokenizer.vocab_size != target_tokenizer.vocab_size:
raise ValueError("Draft and target models have different vocab sizes.")
tokenizer = target_tokenizer
# Setup prompt
prompt_style = (
load_prompt_style(target_model_checkpoint_dir)
if has_prompt_style(target_model_checkpoint_dir)
else PromptStyle.from_config(target_config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
# Initialize models
fabric.print(f"Loading draft model {str(draft_checkpoint_path)!r} with {draft_config.__dict__}", file=sys.stderr)
fabric.print(f"Loading target model {str(target_checkpoint_path)!r} with {target_config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
draft_model = setup_model(draft_config, max_returned_tokens, fabric)
target_model = setup_model(target_config, max_returned_tokens, fabric)
fabric.print(f"Time to instantiate models: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
# Setup compilation if needed
if compile:
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
global next_token
next_token = torch.compile(next_token, mode="reduce-overhead")
# Load model weights
t0 = time.perf_counter()
load_checkpoint(fabric, draft_model, draft_checkpoint_path)
load_checkpoint(fabric, target_model, target_checkpoint_path)
fabric.print(f"Time to load the models weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
# Generate samples
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y, acceptance_rate = generate(
draft_model,
target_model,
encoded,
max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=([tokenizer.eos_id] if tokenizer.eos_id is not None else []),
speculative_k=speculative_k,
)
t = time.perf_counter() - t0
# Reset KV cache
for model in (draft_model, target_model):
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
# Print results
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
print(f"Acceptance rate: {acceptance_rate * 100:.2f}%")
fabric.print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/generate/tp.py
================================================
"""Tensor-parallel implementation adapted from https://github.com/pytorch-labs/gpt-fast/blob/14df27/tp.py"""
import logging
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional, Union
import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.utilities import rank_zero_only
import litgpt.generate.base as generate_base
from litgpt.config import Config
from litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0
from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_nvlink_connectivity,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
)
def tensor_parallel_linear(fabric: L.Fabric, linear: torch.nn.Linear, style: str) -> None:
world_size = fabric.world_size
dim, attr = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}[style]
size = getattr(linear, attr)
if size % world_size != 0:
raise ValueError(
f"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})"
)
shard = torch.tensor_split(linear.weight, world_size, dim=dim)[fabric.global_rank]
# overwrite `.data` instead of recreating the parameter for quantization (bitsandbytes) support.
# the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses
linear.weight.data = shard
setattr(linear, attr, shard.size(dim))
if linear.bias is not None and dim == 0:
shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank]
linear.bias = torch.nn.Parameter(shard, requires_grad=linear.bias.requires_grad)
def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:
if isinstance(mlp, LLaMAMLP):
tensor_parallel_linear(fabric, mlp.fc_1, "colwise")
tensor_parallel_linear(fabric, mlp.fc_2, "colwise")
tensor_parallel_linear(fabric, mlp.proj, "rowwise")
mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))
elif isinstance(mlp, GptNeoxMLP):
tensor_parallel_linear(fabric, mlp.fc, "colwise")
tensor_parallel_linear(fabric, mlp.proj, "rowwise")
mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))
elif isinstance(mlp, LLaMAMoE):
# we use expert slicing across ranks, alternatively, we could create a expert parallelism group
# when the number of experts is a multiple of the world size
for expert in mlp.experts:
tensor_parallel_mlp(fabric, expert)
else:
raise NotImplementedError
def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:
tensor_parallel_linear(fabric, attn.qkv, "colwise")
tensor_parallel_linear(fabric, attn.proj, "rowwise")
attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))
def all_reduce_output(world_size: int, module: torch.nn.Module, ins, outs) -> torch.Tensor:
from torch.distributed._functional_collectives import all_reduce
return all_reduce(outs, "sum", list(range(world_size)))
def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT:
for block in model.transformer.h:
tensor_parallel_mlp(fabric, block.mlp)
tensor_parallel_attn(fabric, block.attn)
# update the config values to the shard sizes
# this is only relevant for `tensor_parallel_attn`, but it needs to run only once
world_size = fabric.world_size
attrs = ["n_head", "n_embd", "n_query_groups"]
for attr in attrs:
size = getattr(model.config, attr)
if size % world_size != 0:
raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})")
setattr(model.config, attr, size // world_size)
return model
@torch.inference_mode()
def main(
checkpoint_dir: Path,
prompt: str = "What food do llamas eat?",
*,
sys_prompt: Optional[str] = None,
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: Optional[int] = 50,
top_p: float = 1.0,
temperature: float = 0.8,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"]] = None,
precision: Optional[str] = None,
compile: bool = False,
) -> None:
"""Generation script that uses tensor parallelism to run across devices.
Generates text samples based on a pre-trained model and tokenizer.
Args:
checkpoint_dir: The checkpoint directory to load.
prompt: The prompt string to use for generating the samples.
sys_prompt: The system prompt to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to compile the model.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None:
if compile:
raise NotImplementedError # untested
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:
warnings.warn(
"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization."
)
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
bnb_logger = logging.getLogger("lightning.fabric.plugins.precision.bitsandbytes")
bnb_logger.setLevel(logging.DEBUG)
bnb_logger.debug = rank_zero_only(bnb_logger.debug)
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
# set "ddp" as the strategy for the launching functionality, but there's no data-parallelism
fabric = L.Fabric(devices="auto", strategy="ddp", precision=precision, plugins=plugins)
if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1:
check_nvlink_connectivity(fabric)
fabric.launch()
check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")
model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
)
prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
# cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced
# which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert
# still, use init_tensor for the precision
with fabric.init_tensor(), torch.device("meta"):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
# sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device
# so that the CPU RAM doesn't OOM with larger models
for rank in range(fabric.world_size):
if fabric.global_rank == rank:
t0 = time.perf_counter()
state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
model.load_state_dict(state_dict, assign=True)
print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
# cannot use `.setup_module` because it will wrap with DDP
model = fabric._precision.convert_module(model)
t0 = time.perf_counter()
model = tensor_parallel(fabric, model)
print(
f"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.",
file=sys.stderr,
)
with fabric.init_tensor():
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
t0 = time.perf_counter()
model = fabric.to_device(model)
print(f"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
fabric.barrier()
if compile:
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.coordinate_descent_tuning = True
generate_base.next_token = torch.compile(generate_base.next_token, mode="reduce-overhead")
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate_base.generate(
model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0
for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(tokenizer.decode(y))
tokens_generated = y.size(0) - prompt_length
fabric.print(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
)
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
================================================
FILE: litgpt/lora.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
# Derived from https://github.com/microsoft/LoRA
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
r"""
Low Ranking Adaptation for LLMs scheme.
┌───────────────────┐
┆ h ┆
└───────────────────┘
▲
|
+
/ \
┌─────────────────┐ ╭───────────────╮ Matrix initialization:
┆ ┆ \ B / B = 0
┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
┆ weights ┆ ╰─────────╯
┆ ┆ | r | r - rank
┆ W e R^(d*d) ┆ | ◀─────▶ |
┆ ┆ ╭─────────╮
└─────────────────┘ / A \
▲ / d*r \
\ ╰───────────────╯
\ ▲
\ /
\ /
┌───────────────────┐
┆ x ┆
└───────────────────┘
With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,
we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates
for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of
course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen
pretrained weights and thus fine-tune the model.
The goal of this approach is to move weight updates into a separate matrix which is decomposed with
two matrices of a lower rank.
"""
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
import litgpt
from litgpt.config import Config as BaseConfig
from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights
class LoRALayer(nn.Module):
def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
"""Store LoRA specific attributes in a class.
Args:
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
"""
super().__init__()
assert r >= 0
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
class LoRALinear(LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
# ↓ this part is for pretrained weights
in_features: int,
out_features: int,
# ↓ the remaining part is for LoRA
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs: Any,
):
"""LoRA wrapper around linear class.
This class has three weight matrices:
1. Pretrained weights are stored as `self.linear.weight`
2. LoRA A matrix as `self.lora_A`
3. LoRA B matrix as `self.lora_B`
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
Args:
in_features: number of input features of the pretrained weights
out_features: number of output features of the pretrained weights
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
"""
super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(torch.empty((r, in_features)))
self.lora_B = nn.Parameter(torch.empty((out_features, r)))
self.scaling = self.lora_alpha / self.r
self.reset_parameters()
def reset_parameters(self) -> None:
"""Reset all the weights, even including pretrained ones."""
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
# Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def get_lora_AB(self) -> torch.Tensor:
"""Return merged lora_A and lora_B matrices with the same shape as the pretrained weights."""
return (self.lora_B @ self.lora_A) * self.scaling
def merge(self) -> None:
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
if self.r > 0 and not self.merged:
pretrained_dtype = self.linear.weight.data.dtype
lora_data = self.get_lora_AB()
# if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result
if pretrained_dtype == torch.uint8:
import bitsandbytes as bnb
weight = self.linear.weight
# dequantize the pretrained weights
weight_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).to(lora_data.dtype)
# add pretrained and LoRA weights
weight_data += lora_data
# assign updated weights and quantize by moving to CUDA device
self.linear.weight = bnb.nn.Params4bit(weight_data, requires_grad=False, **weight.__dict__)
self.linear.weight.cuda(weight.device)
else:
# self.linear might be on CPU and lora_data on CUDA
# the inplace add will preserve the dtype of linear.weight
self.linear.weight.data += lora_data.to(device=self.linear.weight.data.device)
self.merged = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
# if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;
# otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights
pretrained = self.linear(x)
if self.r == 0 or self.merged:
return pretrained
lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return pretrained + lora
class LoRAQKVLinear(LoRALinear):
# LoRA implemented in a dense layer
def __init__(
self,
# ↓ this part is for pretrained weights
in_features: int,
out_features: int,
# ↓ the remaining part is for LoRA
head_size: int,
n_head: int,
n_query_groups: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
**kwargs: Any,
):
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
This class has three weight matrices:
1. Pretrained weights are stored as `self.linear.weight`
2. LoRA A matrix as `self.lora_A`
3. LoRA B matrix as `self.lora_B`
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
Args:
in_features: number of input features of the pretrained weights
out_features: number of output features of the pretrained weights
head_size: size of a single attention head
n_head: number of attention heads
n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`
and `value` but keep `key` without weight updates we should pass `[True, False, True]`
"""
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.head_size = head_size
self.n_head = n_head
self.n_query_groups = n_query_groups
if isinstance(enable_lora, bool):
enable_lora = [enable_lora] * 3
assert len(enable_lora) == 3
self.enable_lora = enable_lora
# Actual trainable parameters
# To better understand initialization let's imagine that we have such parameters:
# ⚬ in_features: 128 (embeddings_size)
# ⚬ out_features: 384 (3 * embedding_size)
# ⚬ r: 2
# ⚬ enable_lora: [True, False, True]
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features))) # (4, 128)
enable_q, enable_k, enable_v = enable_lora
# qkv_shapes will be used to split a tensor with weights correctly
qkv_shapes = (
# if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
# might not be equal to `head_size * n_head`, thus we use it directly here
head_size * n_head * enable_q,
head_size * n_query_groups * enable_k,
head_size * n_query_groups * enable_v,
)
self.qkv_shapes = [s for s in qkv_shapes if s]
self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r)) # (256, 2))
# Notes about shapes above
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
# 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
# F.linear function weights are automatically transposed. In addition conv1d requires channels to
# be before seq length
# - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
# 128*2; 2 tells to have two channels per group for group convolution
# Scaling:
# This balances the pretrained model`s knowledge and the new task-specific adaptation
# https://lightning.ai/pages/community/tutorial/lora-llm/
# So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
# alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
# tune these values to your needs. This value can be even slightly greater than 1.0!
# https://github.com/cloneofsimo/lora
self.scaling = self.lora_alpha / self.r
self.reset_parameters()
@property
def lora_ind(self) -> torch.Tensor:
"""Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
# Indices are needed to properly pad weight updates with zeros.
if not hasattr(self, "_lora_ind"):
enable_q, enable_k, enable_v = self.enable_lora
q_embd_size = self.head_size * self.n_head
kv_embd_size = self.head_size * self.n_query_groups
lora_ind = []
if enable_q:
lora_ind.extend(range(0, q_embd_size))
if enable_k:
lora_ind.extend(range(q_embd_size, q_embd_size + kv_embd_size))
if enable_v:
lora_ind.extend(range(q_embd_size + kv_embd_size, self.linear.out_features))
self.register_buffer(
"_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
)
return self._lora_ind
def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
"""Properly pad the last dimension of weight updates with zeros.
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
then the weights update should be:
[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
[....................................],
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
↑ ↑ ↑
________________________________________
| query | key | value |
----------------------------------------
Args:
x: tensor with weights update that will be padded with zeros if necessary
Returns:
A tensor with weight updates and zeros for deselected q, k or v
"""
# we need to do zero padding only if LoRA is disabled for one of QKV matrices
if all(self.enable_lora):
return x
# Let's image that:
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
# ⚬ embeddings_size: 128
# ⚬ self.linear.out_features: 384 (3 * embeddings_size)
# ⚬ enable_lora: [True, False, True]
# Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
# embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
# only for key updates (this is where self.lora_ind comes in handy)
result = x.new_zeros(*x.shape[:-1], self.linear.out_features) # (64, 64, 384)
if result.device.type == "mps":
result[..., self.lora_ind] = x
return result
else:
return result.index_copy_(dim=-1, index=self.lora_ind, source=x) # (64, 64, 384)
def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
If the number of heads is equal to the number of query groups - grouped queries are disabled
(see scheme in `litgpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
input and weight matrices will be split in equally sized parts and applied separately (like having multiple
conv layers side by side).
Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
apply each part of the weight matrix to the corresponding input's part and concatenate the result.
Args:
input: input matrix of shape (B, C, T)
weight: weight matrix of shape (C_output, rank, 1).
"C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
Returns:
A tensor with a shape (B, C_output, T)
"""
if self.n_head == self.n_query_groups:
return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)
# Notation:
# ⚬ N: number of enabled LoRA layers (self.enable_lora)
# ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
# ⚬ r: rank of all LoRA layers (equal in size)
input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
return torch.cat(
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)],
dim=1, # (B, C_output', T)
) # (B, C_output, T)
def get_lora_AB(self) -> torch.Tensor:
"""Return merged lora_A and lora_B matrices with the same shape as the pretrained weights."""
# Let's assume that:
# ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
# ⚬ self.lora_A.data: (4, 128)
# ⚬ self.lora_B.data: (256, 2)
lora = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128)
def merge(self) -> None:
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
if self.r > 0 and any(self.enable_lora) and not self.merged:
super().merge()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Do the forward pass.
If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
Args:
x: input tensor of shape (batch_size, context_length, embedding_size)
Returns:
Output tensor of shape (batch_size, context_length, 3 * embedding_size)
"""
# Let's assume that:
# ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
# ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
# ⚬ self.lora_A.data: (4, 128)
# ⚬ self.lora_B.data: (256, 2)
# if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass;
# otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights
pretrained = self.linear(x)
if self.r == 0 or not any(self.enable_lora) or self.merged:
return pretrained
after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
# For F.conv1d:
# ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
# ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
return pretrained + lora
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
Args:
model: model with LoRA layers
bias:
``"none"``: all bias weights will be frozen,
``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
``"all"``: all bias weights will be unfrozen.
Raises:
NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
"""
# freeze all layers except LoRA's
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
# depending on the `bias` value unfreeze bias weights
if bias == "none":
return
if bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
def lora_filter(key: str, value: Any) -> bool:
return "lora_" in key
@dataclass
class Config(BaseConfig):
"""
Args:
lora_r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
lora_*: whether to apply LoRA to the specified weights or not
"""
lora_r: int = 0
lora_alpha: int = 1
lora_dropout: float = 0.0
lora_query: bool = False
lora_key: bool = False
lora_value: bool = False
lora_projection: bool = False
lora_mlp: bool = False
lora_head: bool = False
@property
def mlp_class(self) -> Type:
return getattr(litgpt.lora, self.mlp_class_name)
class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = create_lora_linear(
config,
config.n_embd,
config.padded_vocab_size,
bias=config.lm_head_bias,
use_r=config.lora_head,
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.max_seq_length = self.config.block_size
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, LoRALinear):
module.reset_parameters()
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
self.mlp = config.mlp_class(config)
class CausalSelfAttention(BaseCausalSelfAttention):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
# key, query, value projections for all heads, but in a batch
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = LoRAQKVLinear(
in_features=config.n_embd,
out_features=shape,
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
enable_lora=(config.lora_query, config.lora_key, config.lora_value),
bias=config.bias or config.attn_bias,
# for MQA/GQA support
head_size=config.head_size,
n_head=config.n_head,
n_query_groups=config.n_query_groups,
)
# output projection
self.proj = create_lora_linear(
config,
config.head_size * config.n_head,
config.n_embd,
use_r=config.lora_projection,
)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def create_lora_linear(
config: Config,
in_size: int,
out_size: int,
bias: Optional[Union[float, bool]] = None,
use_r: Optional[bool] = None,
) -> LoRALinear:
if bias is None:
bias = config.bias
if use_r is None:
use_r = config.lora_mlp
return LoRALinear(
in_size,
out_size,
bias=bias,
r=(config.lora_r if use_r else 0),
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
)
class GptNeoxMLP(litgpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = create_lora_linear(config, config.n_embd, config.intermediate_size)
self.proj = create_lora_linear(config, config.intermediate_size, config.n_embd)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc.weight": "fc.linear.weight",
"fc.bias": "fc.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class LLaMAMLP(litgpt.model.LLaMAMLP):
def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
nn.Module.__init__(self)
self.intermediate_size = intermediate_size or config.intermediate_size
self.fc_1 = create_lora_linear(config, config.n_embd, self.intermediate_size)
self.fc_2 = create_lora_linear(config, config.n_embd, self.intermediate_size)
self.proj = create_lora_linear(config, self.intermediate_size, config.n_embd)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc_1.weight": "fc_1.linear.weight",
"fc_1.bias": "fc_1.linear.bias",
"fc_2.weight": "fc_2.linear.weight",
"fc_2.bias": "fc_2.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
return self.proj(x)
class LLaMAMoE(litgpt.model.LLaMAMoE):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.gate = create_lora_linear(config, config.n_embd, config.n_expert, bias=False)
self.experts = nn.ModuleList(
LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"gate.weight": "gate.linear.weight"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def merge_lora_weights(model: GPT) -> None:
"""Merge LoRA weights into the full-rank weights to speed up inference."""
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
================================================
FILE: litgpt/model.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Full definition of a decoder-only transformer-based language model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""
import math
from functools import partial
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import Self
from litgpt.config import Config
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
class GPT(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.max_seq_length = self.config.block_size
@property
def max_seq_length(self) -> int:
return self._max_seq_length
@max_seq_length.setter
def max_seq_length(self, value: int) -> None:
"""
When doing inference, the sequences used might be shorter than the model's context length.
This allows setting a smaller number to avoid allocating unused memory
"""
if value > self.config.block_size:
raise ValueError(
f"Cannot attend to {value}, block size is only {self.config.block_size}."
" This is likely because the input text exceeds the supported context length of this model."
)
self._max_seq_length = value
if not hasattr(self, "cos"):
# first call
cos, sin = self.rope_cache()
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
# override
elif value != self.cos.size(0):
self.cos, self.sin = self.rope_cache(device=self.cos.device)
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
# if the kv cache is expected
if self.mask_cache is not None and self.mask_cache.shape[-1] < value:
print(
f"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!"
)
def reset_parameters(self) -> None:
# Trigger resetting the rope-cache
self.cos, self.sin = self.rope_cache(device=self.cos.device)
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
if isinstance(module, GroupedTopkRouter):
torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
idx: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
input_pos_maxp1: Optional[int] = None,
lm_head_chunk_size: int = 0,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
If `input_pos` is provided, the KV cache uses K and V vectors for
positions smaller than entries in `input_pos`. For efficiency, pass
`input_pos_maxp1` as `max(input_pos) + 1` if already available from
your forward algorithm. This slices the KV cache buffers and speeds
up multi-head attention.
Without `input_pos_maxp1`, the computation uses the full KV cache
(`max_seq_length`) with masking applied. Note that inferring
`input_pos_maxp1` from `input_pos` causes graph breaks and prevents
compilation.
Args:
idx: Token indices of input sequences, shape `(B, T)`, where `B`
is batch size.
input_pos: Optional. Positions of input tokens. The default is
`arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index).
input_pos_maxp1: Optional. See above.
lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final
`lm_head` computation is done in chunks of this size.
Returns:
Logit outputs, shape `(B, T, config.padded_vocab_size)`. If
`lm_head_chunk_size > 0`, this is a list of chunks of shape
`(B, lm_head_chunk_size, config.padded_vocab_size)`, the final
entry can be shorter.
"""
T = idx.size(1)
if self.max_seq_length < T:
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
if input_pos is not None: # use the kv cache
if input_pos.dim() > 2:
# otherwise, things go wrong in `apply_rope`
raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}")
if input_pos.shape[-1] != T:
raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same")
cos = batched_index_select(self.cos, 0, input_pos)
sin = batched_index_select(self.sin, 0, input_pos)
if input_pos.dim() == 1:
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
if self.mask_cache is None:
raise TypeError("You need to call `gpt.set_kv_cache()`")
mask = batched_index_select(self.mask_cache, 2, input_pos)
if mask.dim() > 4:
# the mask cache has a batch dim of 1 in addition to the one
# we get if input_pos has a batch dimension
mask = mask.view(*(mask.shape[0:1] + mask.shape[2:]))
if input_pos_maxp1 is not None:
# Shorten final dimension so it just covers all `input_pos` entries
if input_pos_maxp1 > self.max_seq_length:
raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})")
mask = mask[..., :input_pos_maxp1]
else:
# unsqueeze to have a batch dimension
cos = self.cos[:T].unsqueeze(0)
sin = self.sin[:T].unsqueeze(0)
# `cos`, `sin` have shape (1, T, config.rope_n_elem)
mask = None # defaults to causal mask
input_pos_maxp1 = None
x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
if self.config.scale_embeddings:
x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)
for block_idx, block in enumerate(self.transformer.h):
if self.config.rope_indices is not None:
x = block(
x,
cos[..., self.config.rope_indices[block_idx]],
sin[..., self.config.rope_indices[block_idx]],
mask,
input_pos,
input_pos_maxp1,
)
else:
x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)
x = self.transformer.ln_f(x)
clamp_head = (
partial(do_softcapping, thresh=self.config.final_logit_softcapping)
if self.config.final_logit_softcapping is not None
else nn.Identity()
)
if lm_head_chunk_size > 0:
# chunk the lm head logits to reduce the peak memory used by autograd
return [clamp_head(self.lm_head(x_i)) for x_i in x.split(lm_head_chunk_size, dim=1)]
else:
return clamp_head(self.lm_head(x)) # (B, T, padded_vocab_size)
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
if self.config.rope_adjustments is None:
extra_config = None
else:
# Check for mutually exclusive parameter sets
llama3_params = ["low_freq_factor", "high_freq_factor"]
yarn_params = ["beta_fast", "beta_slow"]
has_llama3 = any(param in self.config.rope_adjustments for param in llama3_params)
has_yarn = any(param in self.config.rope_adjustments for param in yarn_params)
if has_llama3 and has_yarn:
raise ValueError(
"RoPE adjustments cannot contain both Llama3 parameters (low_freq_factor, high_freq_factor) "
"and YaRN parameters (beta_fast, beta_slow). These are mutually exclusive."
)
# Llama3-style RoPE
if has_llama3:
adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"]
params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]
if all(params_present):
extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
else:
missing_params = [
param for param, present in zip(adjusted_params_required, params_present) if not present
]
raise ValueError(
f"The following Llama3 RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
"All Llama3 parameters must be specified together."
)
# YaRN-style RoPE
elif has_yarn:
# Required: factor, beta_fast, beta_slow, original_max_seq_len
# Optional: mscale, mscale_all_dim
yarn_required_params = ["factor", "beta_fast", "beta_slow", "original_max_seq_len"]
params_present = [param in self.config.rope_adjustments for param in yarn_required_params]
if not all(params_present):
missing_params = [
param for param, present in zip(yarn_required_params, params_present) if not present
]
raise ValueError(
f"The following YaRN RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
"All YaRN required parameters must be specified together."
)
extra_config = {name: self.config.rope_adjustments[name] for name in yarn_required_params}
# Add optional YaRN parameters
for param in ["mscale", "mscale_all_dim"]:
if param in self.config.rope_adjustments:
extra_config[param] = self.config.rope_adjustments[param]
# Linear or standard RoPE
elif "factor" in self.config.rope_adjustments:
# linear RoPE
adjusted_params_required = ["factor"]
extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}
else:
extra_config = None # uses standard RoPE
return build_rope_cache(
seq_len=self.max_seq_length,
n_elem=self.config.rope_n_elem,
device=device,
condense_ratio=self.config.rope_condense_ratio,
base=self.config.rope_base,
extra_config=extra_config,
rope_local_base_freq=self.config.rope_local_base_freq,
)
def rope_cache_length(self) -> int:
"""
Extract the head dimension (n_elem) from RoPE cache regardless of shape.
The RoPE cache can have different shapes depending on model configuration:
- Standard RoPE: (seq_len, n_elem) - 2D tensor
- Dual RoPE (local/global): (seq_len, n_elem, 2) - 3D tensor
Returns:
int: n_elem (head dimension for RoPE)
"""
return self.cos.size(1)
def set_kv_cache(
self,
batch_size: int,
max_seq_length: Optional[int] = None,
rope_cache_length: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
if rope_cache_length is None:
rope_cache_length = self.rope_cache_length()
if max_seq_length is None:
max_seq_length = self.max_seq_length
# initialize the kv cache for all blocks
for block in self.transformer.h:
block.attn.kv_cache = block.attn.build_kv_cache(
batch_size,
max_seq_length,
rope_cache_length,
device,
dtype,
)
if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
# passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
# for the kv-cache support (only during inference), we only create it in that situation
self.mask_cache = build_mask_cache(max_seq_length, device)
def clear_kv_cache(self) -> None:
self.mask_cache = None
for block in self.transformer.h:
block.attn.kv_cache = None
class Block(nn.Module):
def __init__(
self,
config: Config,
block_idx: int,
) -> None:
super().__init__()
if not config.parallel_residual and config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration"
" (non-parallel residual and shared attention norm)."
)
self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = (
CausalSelfAttention(config, block_idx)
if not config.latent_attention
else MultiheadLatentAttention(config, block_idx)
)
self.post_attention_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
)
self.norm_2 = (
nn.Identity()
if not config.norm_2
else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
)
self.mlp = config.mlp_class(config)
if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
self.mlp = LLaMAMLP(config)
self.post_mlp_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
)
self.config = config
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
input_pos_maxp1: Optional[int] = None,
) -> torch.Tensor:
"""
Non-parallel residual Parallel residual
┌─ x ┌─ x ──────────────────┐ Note: if `shared_attention_norm` is True,
│ ↓ │ ↓ ↓ the output from `norm_1` is reused
│ norm_1 │ norm_1 ───────► norm_2
│ ↓ │ ↓ ↓
│ attn │ attn MLP
│ ↓ │ ↓ ↓
| post_attn_norm | post_attn_norm post_mlp_norm
| ↓ | ↓ ↓
┌─ └► + └► + ◄─────────────────┘
| ↓
│ norm_2
│ ↓
│ MLP
│ ↓
| post_mlp_norm
| ↓
└───► +
"""
x_normed = self.norm_1(x)
attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1)
attention_output = self.post_attention_norm(attention_output)
if self.config.parallel_residual:
if not self.config.shared_attention_norm:
x_normed = self.norm_2(x)
x = attention_output + x
else:
x = attention_output + x
x_normed = self.norm_2(x)
return self.post_mlp_norm(self.mlp(x_normed)) + x
class CausalSelfAttention(nn.Module):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__()
# key, query and value projections for all heads, but in a batch
self.qkv = nn.Linear(
config.n_embd,
(config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries
bias=config.bias or config.attn_bias,
)
# output projection
self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None
self.apply_sliding_window_attention = False
if config.sliding_window_size is not None and config.sliding_window_indices is not None:
self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]
if config.norm_qk:
norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
norm_k_size = (
config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
)
self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)
self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)
else:
self.norm_q = self.norm_k = None
if config.rope_adjustments is not None:
mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None)
scaling_factor = config.rope_adjustments.get("factor", None)
if mscale_all_dim and scaling_factor: # YaRN
self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
else:
self.mscale = 1.0
else:
self.mscale = 1.0
self.config = config
self.block_idx = block_idx
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
input_pos_maxp1: Optional[int] = None,
) -> torch.Tensor:
# Notation:
# - B | batch size
# - T | time-step (sequence length)
# - C | model's embeddings size (n_embd)
# - C* | attentions's embeddings size
# - hs | head size
# - nh_(q,k,v) | number of heads for query, key and value
# - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
# alternative notation: num_kv_groups = n_query_groups
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ │ │ │
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
head_size = self.config.head_size
n_head = self.config.n_head
n_query_groups = self.config.n_query_groups
rope_n_elem = self.config.rope_n_elem
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value`
# instead of individually multiplying the input `x` with the respective weight matrices.
qkv = self.qkv(x) # (B, T, 3xC*)
# Define query, key and value sizes.
# If grouped/multi query is enabled, these sizes are not equal (see the diagram above).
query_size = n_head * head_size
key_size = value_size = n_query_groups * head_size
# Split qkv into query, key and value matrices.
q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*)
if self.config.norm_qk and self.config.norm_qk_type == "olmo2":
q = self.norm_q(q)
k = self.norm_k(k)
# To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
# embedding size (C) into num_heads (nh) and head_size (hs).
# The original GQA paper is followed here and the term query groups is used.
# alternative notation: Query groups are also referred to as KV groups.
q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs)
k = k.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs)
v = v.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs)
# The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are
# multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector
# of size `hs`.
q = q.transpose(1, 2) # (B, nh_q, T, hs)
k = k.transpose(1, 2) # (B, nh_k, T, hs)
v = v.transpose(1, 2) # (B, nh_v, T, hs)
if self.config.norm_qk and self.config.norm_qk_type == "default":
q = self.norm_q(q)
k = self.norm_k(k)
# Unlike standard positional embeddings rotary embeddings must be applied at every layer.
if self.config.rope_interleave:
q_roped = apply_rope_interleave(q[..., :rope_n_elem], cos, sin)
k_roped = apply_rope_interleave(k[..., :rope_n_elem], cos, sin)
else:
q_roped = apply_rope(q[..., :rope_n_elem], cos, sin)
k_roped = apply_rope(k[..., :rope_n_elem], cos, sin)
q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs)
k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs)
# Apply kv-cache during inference.
if input_pos is not None:
if not isinstance(self.kv_cache, KVCache):
raise TypeError("You need to call `gpt.set_kv_cache()`")
k, v = self.kv_cache(input_pos, k, v)
if self.apply_sliding_window_attention:
actual_kv_len = k.size(2)
if mask is not None and mask.size(-1) != actual_kv_len:
mask = mask[..., :actual_kv_len]
if input_pos_maxp1 is not None:
# Subselect along sequence dimension
k = k[..., :input_pos_maxp1, :]
v = v[..., :input_pos_maxp1, :]
# k, v: (B, nh_k, input_pos_maxp1, hs)
# If input_pos_maxp1 is None -> max_seq_length
# Grouped queries: balance the number of heads across all three matrices.
# NOTE: flash attention requires it in training mode.
# Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
if n_query_groups != n_head and (input_pos is None or n_query_groups != 1):
q_per_kv = n_head // n_query_groups
k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)
v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)
if self.apply_sliding_window_attention:
"""
Global Window Sliding window Sliding window
attention mask + bias = attention mask
┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐
│ True False False False │ │ True True True True │ │ True False False False │
│ True True False False │ │ True True True True │ │ True True False False │
│ True True True False │ │ False True True True │ │ False True True False │
│ True True True True │ │ False False True True │ │ False False True True │
└────────────────────────┘ └───────────────────────┘ └─────────────────────────┘
"""
if input_pos is None:
if mask is None:
mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), float("-inf"))
mask = mask.view(1, 1, *mask.shape)
sliding_window_mask = torch.full((T, T), float("-inf"), dtype=q.dtype, device=q.device)
for i in range(T):
window_start = max(0, i - self.config.sliding_window_size + 1)
sliding_window_mask[i, window_start : i + 1] = 0.0
sliding_window_mask = sliding_window_mask.view(1, 1, T, T)
mask = sliding_window_mask
# Efficient attention using Flash Attention CUDA kernels.
# NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
# ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
y = self.scaled_dot_product_attention(q, k, v, mask)
# Re-assemble all head outputs side by side.
y = y.reshape(B, T, head_size * n_head)
# Output projection.
return self.proj(y) # (B, T, C)
def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size)
scale = scale * self.mscale * self.mscale
# with softcapping we cannot use SDPA
if self.config.attention_logit_softcapping is not None:
scores = q @ k.mT * scale
scores = do_softcapping(scores, self.config.attention_logit_softcapping)
if mask is None:
mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
scores = scores + mask
scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
y = scores @ v
else:
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
)
return y.transpose(1, 2)
def build_kv_cache(
self,
batch_size: int,
max_seq_length: int,
rope_cache_length: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> "KVCache":
if self.apply_sliding_window_attention and self.config.sliding_window_size is not None:
effective_cache_size = min(max_seq_length, self.config.sliding_window_size)
else:
effective_cache_size = max_seq_length
v_shape = (batch_size, self.config.n_query_groups, effective_cache_size, self.config.head_size)
if rope_cache_length is None:
if self.config.rotary_percentage != 1.0:
raise TypeError(
"Please pass the `rope_cache_length` parameter. "
"Use `rope_cache_length=model.rope_cache_length()` to extract it automatically."
)
k_shape = v_shape
else:
k_shape = (
batch_size,
self.config.n_query_groups,
effective_cache_size,
rope_cache_length + self.config.head_size - self.config.rope_n_elem,
)
return KVCache(
k_shape,
v_shape,
device=device,
dtype=dtype,
is_sliding_window=self.apply_sliding_window_attention,
sliding_window_size=self.config.sliding_window_size if self.apply_sliding_window_attention else None,
)
def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with legacy checkpoints."""
for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.{attr}"
current_key = f"{prefix}qkv.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class MultiheadLatentAttention(nn.Module):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__()
self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias)
self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps)
self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias)
self.kv_a_proj_with_mqa = nn.Linear(
config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias
)
self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps)
self.kv_b_proj = nn.Linear(
config.kv_lora_rank,
config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim),
bias=config.bias,
)
# output projection
self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias)
# disabled by default
self.kv_cache: Optional[KVCache] = None
if config.rope_adjustments is not None:
mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None)
scaling_factor = config.rope_adjustments.get("factor", None)
if mscale_all_dim and scaling_factor: # YaRN
self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
else:
self.mscale = 1.0
else:
self.mscale = 1.0
self.config = config
self.block_idx = block_idx
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
input_pos_maxp1: Optional[int] = None,
) -> torch.Tensor:
# Notation:
# - B | batch size
# - T | time-step (sequence length)
# - C | model's embeddings size (n_embd)
# - C* | attentions's embeddings size
# - hs | head size
# - nh_(q,k,v) | number of heads for query, key and value
# - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
# alternative notation: num_kv_groups = n_query_groups
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x))) # (B, T, n_head * qk_head_dim)
q = q.view(B, T, -1, self.config.qk_head_dim) # (B, T, n_head, qk_head_dim)
q = q.transpose(1, 2) # (B, n_head, T, qk_head_dim)
q_pass, q_rot = torch.split(q, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(x) # (B, T, kv_lora_rank + qk_rope_head_dim)
k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1)
k_pass = self.kv_b_proj(self.kv_a_norm(k_pass))
k_pass = k_pass.view(B, T, self.config.n_query_groups, -1)
k_pass = k_pass.transpose(1, 2)
k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1)
k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim) # (B, 1, T, qk_rope_head_dim)
# Unlike standard positional embeddings rotary embeddings must be applied at every layer.
if self.config.rope_interleave:
q_roped = apply_rope_interleave(q_rot, cos, sin)
k_roped = apply_rope_interleave(k_rot, cos, sin)
else:
q_roped = apply_rope(q_rot, cos, sin)
k_roped = apply_rope(k_rot, cos, sin)
k_roped = k_roped.expand(*k_pass.shape[:-1], -1) # (B, n_head, T, qk_rope_head_dim)
q = torch.cat((q_pass, q_roped), dim=-1)
k = torch.cat((k_pass, k_roped), dim=-1)
# Apply kv-cache during inference.
if input_pos is not None:
if not isinstance(self.kv_cache, KVCache):
raise TypeError("You need to call `gpt.set_kv_cache()`")
k, v = self.kv_cache(input_pos, k, v)
if input_pos_maxp1 is not None:
# Subselect along sequence dimension
k = k[..., :input_pos_maxp1, :]
v = v[..., :input_pos_maxp1, :]
# k, v: (B, nh_k, input_pos_maxp1, hs)
# If input_pos_maxp1 is None -> max_seq_length
# Grouped queries: balance the number of heads across all three matrices.
# NOTE: flash attention requires it in training mode.
# Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
q_per_kv = self.config.n_head // self.config.n_query_groups
k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)
v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs)
# Efficient attention using Flash Attention CUDA kernels.
# NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
# ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
y = self.scaled_dot_product_attention(q, k, v, mask)
# Re-assemble all head outputs side by side.
y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim)
# Output projection.
return self.proj(y) # (B, T, C)
def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim)
scale = scale * self.mscale * self.mscale
# with softcapping we cannot use SDPA
if self.config.attention_logit_softcapping is not None:
scores = q @ k.mT * scale
scores = do_softcapping(scores, self.config.attention_logit_softcapping)
if mask is None:
mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
scores = scores + mask
scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
y = scores @ v
else:
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
)
return y.transpose(1, 2)
def build_kv_cache(
self,
batch_size: int,
max_seq_length: int,
rope_cache_length: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> "KVCache":
v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim)
k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim)
if rope_cache_length is not None:
print("Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!")
if self.config.rotary_percentage != 1.0:
print("Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!")
return KVCache(k_shape, v_shape, device=device, dtype=dtype)
class GptNeoxMLP(nn.Module):
def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
super().__init__()
self.intermediate_size = intermediate_size or config.intermediate_size
self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
self.config = config
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc(x)
x = F.gelu(x, approximate=self.config.gelu_approximate)
return self.proj(x)
class LLaMAMLP(nn.Module):
def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
super().__init__()
self.intermediate_size = intermediate_size or config.intermediate_size
self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
self.config = config
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = F.silu(x_fc_1) * x_fc_2
return self.proj(x)
class GemmaMLP(LLaMAMLP):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = F.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
return self.proj(x)
class LLaMAMoE(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.gate = (
nn.Linear(config.n_embd, config.n_expert, bias=False)
if not config.n_expert_groups
else GroupedTopkRouter(config)
)
self.experts = nn.ModuleList(
LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
)
if config.n_shared_expert:
self.shared_experts = LLaMAMLP(
config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert
)
self.config = config
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
See also figure 1 in https://arxiv.org/abs/2211.15841
"""
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
residual_x = x.clone()
x = x.view(-1, C) # (B*T, C)
if not self.config.n_expert_groups:
router = self.gate(x) # (B*T, n_expert)
probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token)
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
else:
probs, indices = self.gate(x)
if self.config.routed_scaling_factor != 1.0:
probs = probs * self.config.routed_scaling_factor
masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
y = torch.zeros_like(x) # (B*T, C)
for mask, expert in zip(masks, self.experts):
token_idx, expert_idx = torch.where(mask)
y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
y = y.view(B, T, C)
if self.config.n_shared_expert:
y = y + self.shared_experts(residual_x)
return y
class GroupedTopkRouter(nn.Module):
"""
Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.
DeepseekV3TopkRouter class.
"""
def __init__(self, config: Config) -> None:
super().__init__()
self.config = config
self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))
self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert))
@torch.no_grad()
def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
.topk(self.config.n_topk_scores_per_group, dim=-1)[0] # Top k scores for each group
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
.reshape(-1, self.config.n_expert)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]
return topk_indices
def forward(self, x: torch.Tensor) -> torch.Tensor:
router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))
scores = router_logits.sigmoid()
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
if self.config.norm_topk_prob:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
return topk_weights, topk_indices
# ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def build_rope_cache(
seq_len: int,
n_elem: int,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
rope_local_base_freq: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Enhanced Transformer with Rotary Position Embedding.
Args:
seq_len (int): Sequence length.
n_elem (int): Number of elements (head dimension).
device (torch.device, optional): Device for tensor allocations.
base (int, optional): Base for computing inverse frequencies.
condense_ratio (int, optional): Ratio to condense the position indices.
extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
Shapes are `(seq_len, n_elem)`.
"""
# Compute the inverse frequencies theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
# Initialize attention scaling factor (modified for YaRN)
attention_scaling = 1.0
if extra_config is not None:
factor = extra_config["factor"]
# Check YaRN first (has beta_fast/beta_slow)
if "beta_fast" in extra_config or "beta_slow" in extra_config:
# YaRN-style RoPE scaling
beta_fast = extra_config["beta_fast"]
beta_slow = extra_config["beta_slow"]
original_max_seq_len = extra_config["original_max_seq_len"]
# Calculate attention scaling factor based on mscale and mscale_all_dim
mscale = extra_config.get("mscale")
mscale_all_dim = extra_config.get("mscale_all_dim")
if mscale and mscale_all_dim:
attention_scaling = yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)
elif mscale_all_dim:
attention_scaling = yarn_get_mscale(factor, mscale_all_dim)
elif mscale:
attention_scaling = yarn_get_mscale(factor, mscale)
# else: attention_scaling remains 1.0
# Create two frequency sets: extrapolation (unscaled) and interpolation (scaled)
pos_freqs = base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)
theta_extrapolation = 1.0 / pos_freqs
theta_interpolation = 1.0 / (factor * pos_freqs)
# Find correction range based on rotation counts
# Inverse dimension formula to find dimension based on number of rotations
def find_correction_dim(num_rotations, dim, base_val, max_pos):
return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base_val))
low_dim = find_correction_dim(beta_fast, n_elem, base, original_max_seq_len)
high_dim = find_correction_dim(beta_slow, n_elem, base, original_max_seq_len)
# Apply truncation if specified
if extra_config.get("truncate", True):
low_dim = math.floor(low_dim)
high_dim = math.ceil(high_dim)
low_dim = max(low_dim, 0)
high_dim = min(high_dim, n_elem // 2 - 1)
# Create linear ramp factor for blending
dim_range = torch.arange(n_elem // 2, device=device, dtype=torch.float32)
if low_dim == high_dim:
high_dim += 0.001 # Prevent singularity
linear_func = (dim_range - low_dim) / (high_dim - low_dim)
ramp_func = torch.clamp(linear_func, 0.0, 1.0)
# Blend extrapolation and interpolation frequencies
# ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled)
theta_extrapolation_factor = ramp_func
theta = (
theta_interpolation * (1 - theta_extrapolation_factor)
+ theta_extrapolation * theta_extrapolation_factor
)
elif "original_max_seq_len" in extra_config:
# Llama3-style RoPE scaling
orig_context_len = extra_config["original_max_seq_len"]
low_freq_factor = extra_config["low_freq_factor"]
high_freq_factor = extra_config["high_freq_factor"]
wavelen = 2 * torch.pi / theta
ratio = orig_context_len / wavelen
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
theta = adjusted_theta
else:
# Linear scaling fallback
theta = theta / factor
# Create position indices `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
# If `n_elem` is odd, the final dimension of `idx_theta` has size
# `n_elem + 1`, so need to cut something off.
# Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave
# `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to
# broadcasting. If we shorten `idx_theta`, unit tests comparing to
# Hugging Face fail.
# https://github.com/huggingface/transformers/issues/35233
if idx_theta.shape[-1] > n_elem > 1:
idx_theta = idx_theta[..., :n_elem]
# if rope_local_base_freq is given, have a separate rope value for local embedding
# For now, we use default RoPE for local embedding
if rope_local_base_freq is not None:
local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
local_idx_theta = torch.outer(seq_idx, local_theta)
local_idx_theta = local_idx_theta.repeat(1, 2)
if local_idx_theta.shape[-1] > n_elem > 1:
local_idx_theta = local_idx_theta[..., :n_elem]
idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)
cos = torch.cos(idx_theta) * attention_scaling
sin = torch.sin(idx_theta) * attention_scaling
return cos, sin
def batched_index_select(t, dim, idx):
"""index_select for batched index and unbatched t"""
if idx.dim() == 1:
return torch.index_select(t, dim, idx)
*batch_shape, idx_size = idx.shape
res = torch.index_select(t, dim, idx.reshape(-1)) # flat index
# split out single batch idx
res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
if dim > 0:
# move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
dims = [dim] + list(range(res.dim()))
del dims[dim + 1]
res = res.permute(dims)
# unflatten batch dims
res = res.view(*batch_shape, *res.shape[1:])
return res
def batched_index_copy_(t, dim, idx, val):
"""Index copy for batched t, idx, val"""
if t.device.type == "mps":
# Normalize negative dimensions
if dim < 0:
dim = t.dim() + dim
if idx.dim() == 1:
idx_shape = [1] * val.dim()
idx_shape[dim] = -1
idx_expanded = idx.view(*idx_shape)
idx_expanded = idx_expanded.expand_as(val)
t.scatter_(dim, idx_expanded, val)
return t
elif idx.dim() == 2:
assert dim != 0, "Cannot index the batch dimension"
batch_size = idx.size(0)
idx_size = idx.size(1)
assert batch_size == t.size(0) == val.size(0)
idx_shape = [batch_size] + [1] * (val.dim() - 1)
idx_shape[dim] = idx_size
idx_expanded = idx.view(*idx_shape)
idx_expanded = idx_expanded.expand_as(val)
t.scatter_(dim, idx_expanded, val)
return t
else:
raise NotImplementedError(f"idx.dim() == {idx.dim()} not supported")
else:
if idx.dim() == 1:
return t.index_copy_(dim, idx, val)
assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}"
assert dim != 0, f"cannot index batch dim {dim=}"
batch_size, idx_size = idx.shape
assert batch_size == t.size(0)
assert batch_size == val.size(0)
# if we can view the batch and indexed dimensions together, we could
# do index trickery. This is, sadly, not the case for kvcache so we
# fall back to for loop
for i in range(batch_size):
unbatched_dim = dim if dim < 0 else dim - 1
t[i].index_copy_(unbatched_dim, idx[i], val[i])
return t
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""
Applies RoPE transform to `x`. Note that `cos`, `sin` need to have a batch
dimension.
Args:
x: Input tensor, `(B, ..., T, head_size)`
cos: Cached cosines, `(B, T, head_size)` or `(1, T, head_size)`
sin: Cached sines, `(B, T, head_size)` or `(1, T, head_size)`
Returns:
Encoded tensor, `(B, ..., T, head_size)`
"""
if cos.dim() != 3:
raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}")
if cos.shape != sin.shape:
raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}")
head_size_half = x.size(-1) // 2
x1 = x[..., :head_size_half] # (B, ..., T, head_size/2)
x2 = x[..., head_size_half:] # (B, ..., T, head_size/2)
rotated = torch.cat((-x2, x1), dim=-1) # (B, ..., T, head_size)
dims_diff = x.dim() - cos.dim()
if dims_diff > 0:
# Ensure that shapes of `x`, `cos`, `sin` align
new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]
cos = cos.view(*new_shape)
sin = sin.view(*new_shape)
roped = (x * cos) + (rotated * sin)
return roped.to(dtype=x.dtype)
def apply_rope_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary position embeddings with interleaved tensor layout.
This version rearranges the input tensor to group even/odd indices separately
before applying the standard RoPE rotation, matching HuggingFace's
apply_rotary_pos_emb_interleave behavior.
Args:
x: Input tensor of shape (..., seq_len, head_dim)
cos: Cosine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)
sin: Sine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)
Returns:
Tensor with RoPE applied, same shape as input
"""
if cos.dim() != 3:
raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}")
if cos.shape != sin.shape:
raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}")
# Rearrange tensor to group even/odd indices: [x0,x1,x2,x3,...] -> [x0,x2,x4,...,x1,x3,x5,...]
*batch_dims, d = x.shape
x = x.view(*batch_dims, d // 2, 2).transpose(-1, -2).reshape(*batch_dims, d)
# Standard rotation logic (same as apply_rope)
head_size_half = x.size(-1) // 2
x1 = x[..., :head_size_half]
x2 = x[..., head_size_half:]
rotated = torch.cat((-x2, x1), dim=-1)
# Auto-detect dimension mismatch and reshape cos/sin
dims_diff = x.dim() - cos.dim()
if dims_diff > 0:
new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]
cos = cos.view(*new_shape)
sin = sin.view(*new_shape)
roped = (x * cos) + (rotated * sin)
return roped.to(dtype=x.dtype)
def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor:
return torch.tanh(x / thresh) * thresh
class KVCache(nn.Module):
"""
Buffers `k`, `v` have shape
`(batch_size, n_query_groups, max_seq_length, head_size)`.
"""
def __init__(
self,
k_shape: Tuple[int, int, int, int],
v_shape: Tuple[int, int, int, int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
is_sliding_window: bool = False,
sliding_window_size: Optional[int] = None,
) -> None:
super().__init__()
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
self.is_sliding_window = is_sliding_window
self.sliding_window_size = sliding_window_size
self.max_cache_len = k_shape[2]
def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Writes new values `k` and `v` into the cache at the positions specified
by `input_pos` along the sequence dimension (`max_seq_length`). The batch
size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch
size. Returns the full buffers, adjusted to the batch size `bs`.
Args:
input_pos: Position index, `(bs, T)` or `(T,)`
k: New values, `(bs, n_query_groups, T, head_size)`
v: New values, `(bs, n_query_groups, T, head_size)`
Returns:
k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)`
"""
# move the buffer to the activation dtype for when AMP is used
if self.k.dtype != k.dtype:
self.k = self.k.to(k.dtype)
if self.v.dtype != v.dtype:
self.v = self.v.to(v.dtype)
# update the cache
bs = k.size(0)
if self.is_sliding_window:
# Circular buffer for sliding window
cache_positions = input_pos % self.max_cache_len
k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)
v = batched_index_copy_(self.v[:bs, ...], -2, cache_positions, v)
max_pos = input_pos.max().item()
if max_pos < self.max_cache_len:
k = k[:, :, : max_pos + 1, :]
v = v[:, :, : max_pos + 1, :]
else:
# Standard KV cache (global attention)
k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)
return k, v
def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.k)
torch.nn.init.zeros_(self.v)
def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor:
ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
return torch.tril(ones).unsqueeze(0).unsqueeze(0)
class RMSNorm(torch.nn.Module):
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""
def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
self.add_unit_offset = add_unit_offset
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x = x.float()
# NOTE: the original RMSNorm paper implementation is not equivalent
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
weight = (1 + self.weight) if self.add_unit_offset else self.weight
return (x_normed * weight.float()).to(dtype=dtype)
def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
================================================
FILE: litgpt/parser_config.py
================================================
import sys
from pathlib import Path
from typing import List, Optional
from litgpt.utils import CLI
def parser_commands() -> List[str]:
return [
"download",
"chat",
"finetune",
"finetune_lora",
"finetune_full",
"finetune_adapter",
"finetune_adapter_v2",
"pretrain",
"generate",
"generate_full",
"generate_adapter",
"generate_adapter_v2",
"generate_sequentially",
"generate_speculatively",
"generate_tp",
"convert_to_litgpt",
"convert_from_litgpt",
"convert_pretrained_checkpoint",
"merge_lora",
"evaluate",
"serve",
]
def save_hyperparameters(
function: callable,
checkpoint_dir: Path,
known_commands: Optional[List[str]] = None,
) -> None:
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
from jsonargparse import capture_parser
# TODO: Make this more robust
# This hack strips away the subcommands from the top-level CLI
# to parse the file as if it was called as a script
if known_commands is None:
known_commands = parser_commands()
known_commands = [(c,) for c in known_commands]
for known_command in known_commands:
unwanted = slice(1, 1 + len(known_command))
if tuple(sys.argv[unwanted]) == known_command:
sys.argv[unwanted] = []
parser = capture_parser(lambda: CLI(function))
config = parser.parse_args()
parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
================================================
FILE: litgpt/pretrain.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import math
import pprint
import time
import warnings
from dataclasses import asdict
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops
from torch.utils.data import DataLoader
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal
from litgpt import Tokenizer
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.config import name_to_config
from litgpt.constants import _TORCH_EQUAL_2_7, _TORCH_EQUAL_2_8
from litgpt.data import DataModule, TinyLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from litgpt.parser_config import save_hyperparameters
from litgpt.types import LoggerChoice
from litgpt.utils import (
CycleIterator,
capture_hparams,
check_nvlink_connectivity,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
find_resume_path,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
num_parameters,
parse_devices,
reset_parameters,
save_config,
)
def setup(
model_name: str,
model_config: Optional[Config] = None,
out_dir: Path = Path("out/pretrain"),
precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None,
initial_checkpoint_dir: Optional[Path] = None,
resume: Union[bool, Literal["auto"], Path] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
global_batch_size=512,
micro_batch_size=4,
max_tokens=int(3e12), # 3 trillion
max_norm=1.0,
min_lr=4e-5,
lr_warmup_steps=2000,
tie_embeddings=False,
),
eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
log: LogArgs = LogArgs(),
optimizer: Union[str, Dict] = "AdamW",
devices: Union[int, str] = "auto",
num_nodes: int = 1,
tokenizer_dir: Optional[Path] = None,
logger_name: LoggerChoice = "tensorboard",
seed: int = 42,
):
"""Pretrain a model.
Arguments:
model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Use "list" to list the supported models.
model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
``model_config``. Overrides the `model_name` if specified.
out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
/teamspace/jobs//share.
precision: The precision to use for finetuning. Determines a compatible precision setting by default.
initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
Useful for continued pretraining. Mutually exclusive with ``resume``.
resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing
``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.
data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
optimizer: An optimizer name (such as "AdamW") or config.
devices: How many devices/GPUs to use. Uses all GPUs by default.
num_nodes: How many nodes the code is being run on.
tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
module require this.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""
if model_name == "list":
available_models = "\n".join(sorted(name_to_config))
print(f"Available values:\n{available_models}")
quit()
if initial_checkpoint_dir is not None:
initial_checkpoint_dir = extend_checkpoint_dir(initial_checkpoint_dir)
if tokenizer_dir is not None:
tokenizer_dir = extend_checkpoint_dir(tokenizer_dir)
if model_config is None:
# Support both model_name options: meta-llama/Meta-Llama-3-8B & Meta-Llama-3-8B
try:
model_config = Config.from_name(model_name)
except ValueError:
print(f"Model name {model_name} is not supported.\n")
available_models = "\n".join(sorted(name_to_config))
print(f"Available values:\n{available_models}")
quit()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
config = Config.from_name(model_name) if model_config is None else model_config
precision = precision or get_default_supported_precision(training=True)
devices = parse_devices(devices)
out_dir = init_out_dir(out_dir)
# in case the dataset requires the Tokenizer
tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None
logger = choose_logger(
logger_name,
out_dir,
name=f"pretrain-{config.name}",
resume=bool(resume),
log_interval=train.log_interval,
log_args=asdict(log),
)
if devices * num_nodes > 1:
strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
else:
strategy = "auto"
fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=[logger])
if torch.cuda.is_available() and devices > 1:
check_nvlink_connectivity(fabric)
fabric.launch()
fabric.print(pprint.pformat(hparams))
if logger_name in ("tensorboard", "wandb", "mlflow"):
fabric.logger.log_hyperparams(hparams)
main(
fabric=fabric,
devices=devices,
num_nodes=num_nodes,
seed=seed,
initial_checkpoint_dir=initial_checkpoint_dir,
resume=resume,
config=config,
data=data,
out_dir=out_dir,
tokenizer_dir=tokenizer_dir,
tokenizer=tokenizer,
train=train,
eval=eval,
optimizer=optimizer,
)
def main(
fabric: L.Fabric,
devices: int,
seed: int,
initial_checkpoint_dir: Optional[Path],
resume: Union[bool, Literal["auto"], Path],
config: Config,
data: DataModule,
out_dir: Path,
tokenizer_dir: Optional[Path],
tokenizer: Optional[Tokenizer],
train: TrainArgs,
eval: EvalArgs,
optimizer: Union[str, Dict],
num_nodes: int = 1,
) -> None:
validate_args(train, eval, initial_checkpoint_dir, resume)
if fabric.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)
fabric.seed_everything(seed) # same seed for every process to init model (FSDP)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)
if train.tie_embeddings:
model.transformer.wte.weight = model.lm_head.weight
if train.max_seq_length:
model.max_seq_length = train.max_seq_length
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters: {num_parameters(model):,}")
model = torch.compile(model)
model = fabric.setup(model)
extra_kwargs = {"fused": fabric.device.type == "cuda"}
optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), **extra_kwargs)
optimizer = fabric.setup_optimizers(optimizer)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
if initial_checkpoint_dir:
fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model)
state = {
"model": model,
"optimizer": optimizer,
"train_dataloader": train_dataloader,
"iter_num": 0,
"step_count": 0,
}
resume = find_resume_path(resume, out_dir)
if resume:
fabric.print(f"Resuming training from {resume}")
fabric.load(resume, state)
train_time = time.perf_counter()
# work around PyTorch issue https://github.com/pytorch/pytorch/issues/152162
# which does not like the lazy initialization to be called in dynamo.
# TODO: Happens with PyTorch 2.7+
if (
(_TORCH_EQUAL_2_7 or _TORCH_EQUAL_2_8)
and (model._forward_module.__class__.__name__ == "OptimizedModule")
and (model._forward_module._orig_mod.__class__.__name__ == "FullyShardedDataParallel")
):
from torch.distributed.fsdp._runtime_utils import _root_pre_forward
_root_pre_forward(model._forward_module._orig_mod, model._forward_module._orig_mod, [], {})
fit(
fabric=fabric,
devices=devices,
num_nodes=num_nodes,
state=state,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
out_dir=out_dir,
tokenizer_dir=tokenizer_dir,
train=train,
eval=eval,
)
# Save final checkpoint
save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
total_tokens = state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size
# Print formatted output
separator = "-" * 40
fabric.print(separator)
fabric.print("| Performance")
fabric.print(f"| - Total tokens : {total_tokens:,}")
fabric.print(f"| - Training Time : {(time.perf_counter() - train_time):.2f} s")
fabric.print(f"| - Tok/sec : {total_tokens / train_time:.2f} tok/s")
fabric.print("| " + "-" * 40)
if fabric.device.type == "cuda":
memory_used = torch.cuda.max_memory_allocated() / 1e9
fabric.print("| Memory Usage")
fabric.print(f"| - Memory Used : {memory_used:.2f} GB")
fabric.print(separator)
def fit(
fabric: L.Fabric,
devices: int,
state: dict,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
out_dir: Path,
tokenizer_dir: Optional[Path],
train: TrainArgs,
eval: EvalArgs,
num_nodes: int = 1,
) -> None:
model = state["model"]
optimizer = state["optimizer"]
if eval.initial_validation:
val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
val_loss = f"{val_loss:.3f}"
else:
fabric.print("Verifying settings ...")
validate(fabric, model, val_dataloader, max_iters=2, verbose=False) # sanity check
val_loss = "n/a"
throughput = ThroughputMonitor(fabric, window_size=5)
with torch.device("meta"):
meta_model = GPT(model.config)
x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))
model_fwd = lambda: meta_model(x) # noqa: F821
model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0) # noqa: F821
measured_flops = measure_flops(meta_model, model_fwd, model_loss)
fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x
max_tokens_per_device = train.max_tokens // fabric.world_size
tokens_per_iter = train.micro_batch_size * model.max_seq_length
max_iters = max_tokens_per_device // tokens_per_iter
log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)
initial_iter = state["iter_num"]
train_iterator = CycleIterator(train_dataloader)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(
fabric.device
)
fabric.barrier()
total_t0 = time.perf_counter()
warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)
for train_data in train_iterator:
if state["iter_num"] >= max_iters:
break
# determine and set the learning rate for this iteration
lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
state["iter_num"] += 1
iter_t0 = time.perf_counter()
input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()
targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()
is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices, num_nodes) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets)
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
running_loss.update(loss.detach())
if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=train.max_norm)
optimizer.step()
optimizer.zero_grad()
state["step_count"] += 1
if state["iter_num"] % log_iter_interval == 0:
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=(t1 - total_t0),
flops=(measured_flops * log_iter_interval),
batches=state["iter_num"],
samples=(state["iter_num"] * train.micro_batch_size),
lengths=(state["iter_num"] * train.micro_batch_size * model.max_seq_length),
)
metrics = {
"loss": loss,
"iter": state["iter_num"],
"step": state["step_count"],
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"remaining_time": (
(t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
),
"tokens": state["iter_num"] * train.micro_batch_size * model.max_seq_length,
"total_tokens": (state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size),
"learning_rate": lr,
}
if isinstance(val_loss, float):
val_loss = f"{val_loss:.3f}"
fabric.print(
f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
f" loss train: {metrics['loss']:.3f},"
f" val: {val_loss} |"
f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
f"{' (step)' if not is_accumulating else ''}"
f" remaining time: {timedelta(seconds=int(metrics['remaining_time']))!s}"
)
throughput_metrics = throughput.compute()
metrics.update(throughput_metrics)
fabric.log_dict(metrics, step=state["iter_num"] - 1)
if val_dataloader is not None and not is_accumulating and state["step_count"] % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
val_loss = val_loss.item()
td = time.perf_counter() - t0
fabric.print(f"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms")
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=state["iter_num"] - 1)
fabric.barrier()
if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")
# Final validation
if eval.final_validation:
val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=state["iter_num"])
fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
@torch.no_grad()
def validate(
fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int, verbose: bool = True
) -> torch.Tensor:
fabric.barrier()
if verbose:
fabric.print("Validating ...")
model.eval()
losses = []
for k, batch in enumerate(val_dataloader):
if k >= max_iters:
break
input_ids = batch[:, 0 : model.max_seq_length].contiguous().long()
targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long()
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets)
losses.append(loss)
val_loss = torch.stack(losses).mean()
model.train()
fabric.barrier()
return val_loss
def get_dataloaders(
fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int
) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
return train_dataloader, val_dataloader
# learning rate decay scheduler (cosine with linear warmup)
def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > max_iters, return min learning rate
if it > max_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
def initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None:
"""GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745)."""
# Adapted from https://github.com/jzhang38/TinyLlama
def init_weights(module, std):
nn.init.normal_(module.weight, mean=0.0, std=std)
if getattr(module, "bias", None) is not None:
nn.init.zeros_(module.bias)
for mod in model.modules():
if isinstance(mod, (nn.Embedding, nn.Linear)):
mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd))
# need a separate loop because `mod.proj` below is a `nn.Linear` too
for mod in model.modules():
if isinstance(mod, (LLaMAMLP, CausalSelfAttention)):
mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer))
if not isinstance(fabric.strategy, FSDPStrategy):
reset_parameters(model)
def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
model = state["model"]
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
fabric.save(checkpoint_file, state)
if fabric.global_rank == 0:
save_hyperparameters(setup, checkpoint_file.parent)
if tokenizer_dir is not None:
copy_config_files(tokenizer_dir, checkpoint_file.parent)
save_config(model.config, checkpoint_file.parent)
def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
issues = []
unsupported = [(train, ["epochs"]), (eval, ["max_new_tokens"])]
for args, names in unsupported:
for name in names:
if getattr(args, name) is not None:
issues.append(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
if train.max_steps is not None:
warnings.warn(
"`train.max_steps` is intended for profiling or debug runs only. "
"For full pretraining runs, prefer `train.max_tokens` or `train.max_time`.",
UserWarning,
)
required = [(train, ["max_tokens", "max_norm"])]
for args, names in required:
for name in names:
if getattr(args, name) is None:
issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
if initial_checkpoint_dir and resume:
issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.")
if issues:
raise ValueError("\n".join(issues))
================================================
FILE: litgpt/prompts.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import importlib
import re
from abc import abstractmethod
from json import dumps
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import yaml
from litgpt.config import Config
if TYPE_CHECKING:
from litgpt import Tokenizer
class PromptStyle:
"""Base interface for prompt styles."""
@abstractmethod
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return prompt
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return ([tokenizer.eos_id],)
@classmethod
def from_name(cls, name: str) -> "PromptStyle":
return prompt_styles[name]()
@classmethod
def from_config(cls, config: Config) -> "PromptStyle":
return model_name_to_prompt_style(config.name)
class Default(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return prompt
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return ([tokenizer.eos_id],)
class Alpaca(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
if kwargs.get("input"):
sys_prompt = sys_prompt or (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
)
return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Input:\n{kwargs['input']}\n\n### Response:\n"
sys_prompt = sys_prompt or (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
)
return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n"
class FLAN(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
)
return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n"
class Longform(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
)
return f"{sys_prompt}### Instruction:\n{prompt}\n\n### Response:\n"
class StableLMAlpha(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or (
"# StableLM Tuned (Alpha version)\n- StableLM is a helpful and harmless open-source AI language"
" model developed by StabilityAI.\n- StableLM is excited to be able to help the user, but will refuse to do"
" anything that could be considered harmful to the user.\n- StableLM is more than just an information"
" source, StableLM is also able to write poetry, short stories, and make jokes.\n- StableLM will refuse to"
" participate in anything that could harm a human."
)
return f"<|SYSTEM|>{sys_prompt}<|USER|>{prompt}<|ASSISTANT|>"
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|SYSTEM|>")],
[tokenizer.token_to_id("<|ASSISTANT|>")],
[tokenizer.token_to_id("<|USER|>")],
)
class StableLMZephyr(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n"
class Falcon(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"{prompt}\nAnswer:"
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
# the model rarely emits the eos token and instead outputs newlines, but we cannot use them
# to stop or else things like code generation wouldn't work
[tokenizer.token_to_id("User"), tokenizer.token_to_id(":")],
[193, tokenizer.token_to_id("User")], # 193: '\n'
)
class Falcon3(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n"
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|endoftext|>")],
)
class Llama2FunctionCalling(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
# Has to be before the llama config
b_func, e_func = "", "\n\n"
b_inst, e_inst = "[INST]", "[/INST]"
b_sys, e_sys = "<>\n", "\n<>\n\n"
# This is an example for how to format functions for the model
function_metadata = {
"function": "search_bing",
"description": (
"Search the web for content on Bing. This allows users to search online/the internet/the web for"
" content."
),
"arguments": [{"name": "query", "type": "string", "description": "The search query string"}],
}
system_prompt = sys_prompt or (
"You are a helpful, respectful and honest assistant. Always answer as helpfully as"
"possible. Your only response should be JSON formatted functions"
)
# replace the curly braces with double curly braces to escape them
function_list = dumps(function_metadata).replace("{", "{{").replace("}", "}}")
return (
f"{b_func}{function_list.strip()}{e_func}{b_inst}{b_sys}{system_prompt.strip()}{e_sys}{prompt}{e_inst}\n\n"
)
class Llama2(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
b_inst, e_inst = "[INST]", "[/INST]"
b_sys, e_sys = "<>\n", "\n<>\n\n"
sys_prompt = sys_prompt or (
"You are a helpful, respectful and honest assistant. Always answer as helpfully as"
" possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist,"
" toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and"
" positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why"
" instead of answering something not correct. If you don't know the answer to a question, please don't"
" share false information."
)
return f"{b_inst} {b_sys}{sys_prompt}{e_sys} {prompt} {e_inst} "
class Llama3(PromptStyle):
def apply(
self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str
) -> str:
default_system_prompt = sys_prompt or "You are a helpful assistant."
# https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229
if isinstance(prompt, str):
return (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
f"{default_system_prompt}<|eot_id|>" # No newline
"<|start_header_id|>user<|end_header_id|>\n\n"
f"{prompt}<|eot_id|>" # No newline
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
elif isinstance(prompt, list):
def encode_header(role: str) -> List[str]:
return [f"<|start_header_id|>{role}<|end_header_id|>\n\n"]
def encode_message(message: Dict[str, str]) -> List[str]:
tokens = encode_header(message["role"])
# NOTE: Meta stripped this. I'm not sure I agree, but who am I to argue?
tokens.append(message["content"].strip())
tokens.append("<|eot_id|>")
return tokens
def has_system_prompt(messages: List[Dict[str, str]]) -> bool:
return messages[0].get("role", "") == "system" if len(messages) else False
tokens = ["<|begin_of_text|>"]
if not has_system_prompt(prompt):
tokens.extend(encode_message({"role": "system", "content": default_system_prompt}))
for i, message in enumerate(prompt):
if i != 0 and message["role"] == "system":
raise ValueError("'system' role is only allowed at the beginning of the conversation list.")
if message["role"] not in ["assistant", "user", "system"]:
raise ValueError(
f"Unknown role: '{message['role']}'. Supported roles are 'assistant', 'user', and 'system'."
)
tokens.extend(encode_message(message))
tokens.extend(encode_header("assistant"))
return "".join(tokens)
else:
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|eot_id|>")],
)
class R1Base(PromptStyle):
def apply(
self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str
) -> str:
default_system_prompt = sys_prompt or ""
bos_token = "<|begin▁of▁sentence|>"
eos_token = ""
if isinstance(prompt, str):
return f"{default_system_prompt}<|User|>{prompt}<|Assistant|>" # Prepares for assistant response
elif isinstance(prompt, list):
def encode_message(message: Dict[str, str]) -> str:
role = message["role"]
content = message["content"].strip()
if role == "system":
return content # System prompt is prepended at the start
elif role == "user":
return f"<|User|>{content}"
elif role == "assistant":
return f"<|Assistant|>{content}{eos_token}"
else:
raise ValueError(f"Unknown role: '{role}'. Supported roles are 'assistant', 'user', and 'system'.")
# Extract system prompt (if any)
system_prompt = ""
if prompt[0].get("role") == "system":
system_prompt = prompt[0]["content"]
prompt = prompt[1:] # Remove system message from the list
# Construct the formatted prompt
formatted_prompt = system_prompt
for message in prompt:
formatted_prompt += encode_message(message)
formatted_prompt += "<|Assistant|>" # Prepares for assistant response
return formatted_prompt
else:
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|end▁of▁sentence|>")],
)
class FreeWilly2(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or "This is a system prompt, please behave and help the user."
return f"### System:\n{sys_prompt}\n\n### User:\n{prompt}\n\n### Assistant:\n"
class Platypus(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"### Instruction:\n\n{prompt}\n\n### Response:\n"
class StableCode(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"###Instruction\n{prompt}###Response\n"
class CodeLlama(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
# for CodeLLama, we don't set a default system prompt, but it is supported:
# https://huggingface.co/blog/codellama#conversational-instructions
# Mistral does not: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
b_inst, e_inst = "[INST]", "[/INST]"
if sys_prompt:
b_sys, e_sys = "<>\n", "\n<>\n\n"
return f"{b_inst} {b_sys}{sys_prompt}{e_sys}{prompt} {e_inst}"
return f"{b_inst} {prompt} {e_inst}"
class Phi1(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"{prompt}\n\nAnswer:"
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")],
[198, tokenizer.token_to_id("Answer"), tokenizer.token_to_id(":")],
# the model rarely emits the eos token and instead outputs newlines, but we cannot use them
# to stop or else things like code generation wouldn't work
# [198, 198], # '\n', '\n'
)
class Phi2(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"Instruct: {prompt}\nOutput:"
class Phi3(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or "You are a helpful assistant."
return f"<|system|>\n{sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
class Phi4(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
res = ""
if sys_prompt:
res += f"<|im_start|>system<|im_sep|>{sys_prompt}<|im_end|>"
res += f"<|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>"
return res
class Phi4Reasoning(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = (
sys_prompt
or "You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: {Thought section} {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:"
)
return f"<|im_start>system<|im_sep|>{sys_prompt}<|im_end|><|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>"
class Phi4Mini(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
res = ""
if sys_prompt:
res += f"<|system|>{sys_prompt}<|end|>"
res += f"<|user|>{prompt}<|end|><|assistant|>"
return res
class Phi4MiniReasoning(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or "Your name is Phi, an AI math expert developed by Microsoft."
return f"<|system|>{sys_prompt}<|end|><|user|>{prompt}<|end|><|assistant|>"
class TinyLlama(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or "You are a friendly chatbot who always gives helpful, detailed, and polite answers."
return f"<|system|>\n{sys_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n"
class Gemma(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"user\n{prompt}\nmodel\n"
class OLMo(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"
class ChatML(PromptStyle):
def __init__(self, system_message: Optional[str] = None):
self.system_message = system_message
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
sys_prompt = sys_prompt or self.system_message
return (
f"<|im_start|>system\n{sys_prompt}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
)
class Qwen2_5(ChatML):
def __init__(self):
super().__init__("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")
class Qwen2_5_Math(ChatML):
def __init__(self):
super().__init__("Please reason step by step, and put your final answer within \\boxed{}.")
class QwQ(ChatML):
def __init__(self):
super().__init__(
"You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
)
class Qwen3(ChatML):
def __init__(self):
super().__init__()
class SmolLM2(ChatML):
def __init__(self):
super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face")
class Salamandra(ChatML):
def __init__(self):
super().__init__(
"I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit."
)
# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
# Dataset-specific prompt styles
"alpaca": Alpaca,
"flan": FLAN,
"longform": Longform,
# Model-specific prompt styles
"stablelm-alpha": StableLMAlpha,
"stablelm-zephyr": StableLMZephyr,
"falcon": Falcon,
"llama2-function-calling": Llama2FunctionCalling,
"llama2": Llama2,
"freewilly2": FreeWilly2,
"platypus": Platypus,
"stablecode": StableCode,
"codellama": CodeLlama,
"phi-1": Phi1,
"phi-2": Phi2,
"phi-3": Phi3,
"phi-4": Phi4,
"phi-4-reasoning": Phi4Reasoning,
"phi-4-mini": Phi4Mini,
"phi-4-mini-reasoning": Phi4MiniReasoning,
"tinyllama": TinyLlama,
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"qwen3": Qwen3,
"smollm2": SmolLM2,
"salamandra": Salamandra,
}
def model_name_to_prompt_style(model_name: str) -> PromptStyle:
if re.search(r"stablelm-tuned-alpha", model_name):
return StableLMAlpha()
if re.search(r"stablelm-zephyr-3b", model_name):
return StableLMZephyr()
if re.search("stablecode-instruct", model_name):
return StableCode()
if re.search(r"Falcon3.*-Instruct", model_name):
return Falcon3()
if re.search(r"falcon.*-instruct", model_name):
return Falcon()
if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name):
return Llama2FunctionCalling()
if re.search("Llama-2.*-chat", model_name):
return Llama2()
if re.search("Llama-3.*-Instruct", model_name):
return Llama3()
if re.search("Llama-3.*-Instruct-*", model_name):
return Llama3()
if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
return Llama3()
if re.search("R1", model_name):
return R1Base()
if re.search("FreeWilly2", model_name):
return FreeWilly2()
if re.search("Platypus", model_name):
return Platypus()
if re.search("CodeLlama|Mi[sx]tral.*Instruct", model_name):
return CodeLlama()
if re.search("phi-1", model_name):
return Phi1()
if re.search("phi-2", model_name):
return Phi2()
if re.search("Phi-3", model_name):
return Phi3()
if re.search("Phi-4-reasoning", model_name):
return Phi4Reasoning()
if re.search("Phi-4-mini-reasoning", model_name):
return Phi4MiniReasoning()
if re.search("Phi-4-mini", model_name):
return Phi4Mini()
if re.search("phi-4", model_name):
return Phi4()
if re.search(r"tiny-llama.*chat", model_name):
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
return Gemma()
if re.search(r"OLMo.*-hf", model_name):
return OLMo()
if re.search(r"Qwen2\.5-Math-.*", model_name):
return Qwen2_5_Math()
if re.search(r"Qwen2\.5-.*", model_name):
return Qwen2_5()
if re.search(r"QwQ-.*", model_name):
return QwQ()
if re.search(r"Qwen3-.*", model_name):
return Qwen3()
if re.search(r"SmolLM2.*-Instruct", model_name):
return SmolLM2()
if re.search(r"salamandra-.*-instruct", model_name):
return Salamandra()
return Default()
def save_prompt_style(style: Union[str, PromptStyle], checkpoint_dir: Path) -> None:
style = PromptStyle.from_name(style) if isinstance(style, str) else style
cls = type(style)
# Allow saving the full module path for user-defined prompt classes
config = {"class_path": f"{cls.__module__}.{cls.__name__}"}
with open(checkpoint_dir / "prompt_style.yaml", "w", encoding="utf-8") as file:
yaml.dump(config, file)
def load_prompt_style(checkpoint_dir: Path) -> PromptStyle:
with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file:
config = yaml.safe_load(file)
# Support loading the full module path for user-defined prompt classes
full_module_path, cls_name = config["class_path"].rsplit(".", 1)
module = importlib.import_module(full_module_path)
cls = getattr(module, cls_name)
return cls()
def has_prompt_style(checkpoint_dir: Path) -> bool:
return (checkpoint_dir / "prompt_style.yaml").is_file()
================================================
FILE: litgpt/scripts/__init__.py
================================================
================================================
FILE: litgpt/scripts/convert_hf_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import gc
import json
import os
import re
import warnings
from collections import defaultdict
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
from safetensors.torch import load_file as load_safetensors
from tqdm import tqdm
from litgpt.config import Config
from litgpt.utils import (
extend_checkpoint_dir,
incremental_save,
lazy_load,
save_config,
)
def copy_weights_gpt_neox(
config: Config,
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"gpt_neox.embed_in.weight": "transformer.wte.weight",
"gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias",
"gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight",
"gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias",
"gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight",
"gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None,
"gpt_neox.layers.{}.attention.bias": None,
"gpt_neox.layers.{}.attention.masked_bias": None,
"gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
"gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias",
"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
"gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias",
"gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
"gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias",
"gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight",
"embed_out.weight": "lm_head.weight",
}
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights))
for from_name, param in hf_weights.items():
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template]
if to_name is None:
continue
to_name = to_name.format(layer_idx)
param = load_param(param, from_name, dtype, verbose=debug_mode)
if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")):
# Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_falcon(
config: Config,
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"transformer.word_embeddings.weight": "transformer.wte.weight",
"transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight",
"transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight",
"transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
"transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
"transformer.ln_f.bias": "transformer.ln_f.bias",
"transformer.ln_f.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
# the original model definition is different for each size
if "7b" in config.name:
weight_map.update(
{
"transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
}
)
elif "40b" in config.name or "180B" in config.name:
weight_map.update(
{
"transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias",
"transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight",
"transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias",
"transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight",
}
)
else:
raise NotImplementedError
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights))
for from_name, param in hf_weights.items():
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template].format(layer_idx)
param = load_param(param, from_name, dtype, verbose=debug_mode)
if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")):
# Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_hf_llama(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
"model.norm.weight": "transformer.ln_f.weight",
"model.norm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name == "LLaMAMoE":
weight_map.update(
{
"model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{}.mlp.gate.weight",
"model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
"model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight",
"model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{}.mlp.experts.{}.proj.weight",
}
)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
)
else:
raise NotImplementedError
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_gemma_2(
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
"model.layers.{}.pre_feedforward_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_gemma_3(
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
config: Optional[Config] = None,
) -> None:
GEMMA3_LANGUAGE_MODEL_PREFIX = (
"model.language_model"
if any(k.startswith("model.language_model") for k in hf_weights)
else "language_model.model"
)
GEMMA3_VISION_MODEL_PREFIX = (
"model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower"
)
GEMMA3_MM_PROJECTOR_PREFIX = (
"model.multi_modal_projector"
if any(k.startswith("model.multi_modal_projector") for k in hf_weights)
else "multi_modal_projector"
)
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
"model.layers.{}.pre_feedforward_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
}
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
# gemma3 4b+ are multimodel models, but we are only loading the text weights
is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)
if is_multimodal:
warnings.warn("For Gemma3 models only the text component is supported.")
new_weight_map = dict()
prefix = "model"
for k, v in weight_map.items():
if k.startswith(prefix):
k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
new_weight_map[k] = v
weight_map = new_weight_map
for from_name, param in hf_weights.items():
if from_name.startswith(GEMMA3_VISION_MODEL_PREFIX) or from_name.startswith(GEMMA3_MM_PROJECTOR_PREFIX):
continue
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map.get(name_template)
param = load_param(param, from_name, dtype, verbose=debug_mode)
# in multimodal models, the text weights are the first part of the weights
if is_multimodal and to_name == "transformer.wte.weight" and config is not None:
param = param[: config.vocab_size]
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_phi(
config: Config,
qkv_weights: dict,
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
if any(layer_name.startswith(("layers.", "transformer.")) for layer_name in hf_weights):
raise ValueError(
"You are using an outdated Phi checkpoint. Please reload it as described in 'tutorials/download_phi.md'"
)
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.q_proj.bias": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.k_proj.bias": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.v_proj.bias": None,
"model.layers.{}.self_attn.dense.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.dense.bias": "transformer.h.{}.attn.proj.bias",
"model.layers.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
"model.layers.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
"model.layers.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
"model.layers.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
"model.final_layernorm.weight": "transformer.ln_f.weight",
"model.final_layernorm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}
if config.name.startswith(("Phi-3", "phi-4", "Phi-4")):
weight_map.update(
{
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.norm.weight": "transformer.ln_f.weight",
}
)
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, layer_idx = layer_template(from_name)
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(layer_idx, defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
elif from_name.endswith("gate_up_proj.weight"):
weight = load_param(param, f"layer {layer_idx} gate_up_proj", dtype, verbose=debug_mode)
fc_1, fc_2 = weight.chunk(2, dim=0)
state_dict[f"transformer.h.{layer_idx}.mlp.fc_1.weight"] = fc_1
state_dict[f"transformer.h.{layer_idx}.mlp.fc_2.weight"] = fc_2
continue
to_name = weight_map[name_template]
if to_name is None:
continue
to_name = to_name.format(layer_idx)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict and config.name.startswith("Phi-4"):
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_qwen_2_5(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.q_proj.bias": None,
"model.layers.{}.self_attn.k_proj.bias": None,
"model.layers.{}.self_attn.v_proj.bias": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_olmo2(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
"model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias",
"model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
"model.norm.weight": "transformer.ln_f.weight",
"model.norm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
)
else:
raise NotImplementedError
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def copy_weights_qwen_3(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name == "LLaMAMoE":
weight_map.update(
{
"model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
"model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight",
"model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight",
"model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight",
}
)
elif config.mlp_class_name == "LLaMAMLP":
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
)
else:
raise NotImplementedError
if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if progress_per_file is not None:
pbar.update(progress_per_file)
if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]
if progress_per_file is not None:
pbar.update(progress_per_file)
def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...]
"""
q_per_kv = config.n_head // config.n_query_groups
qs = []
ks = []
vs = []
for chunk in torch.chunk(param, config.n_query_groups):
split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size])
qs.append(split[0])
ks.append(split[1])
vs.append(split[2])
q = torch.cat(qs)
k = torch.cat(ks)
v = torch.cat(vs)
return torch.cat((q, k, v))
def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]:
pattern = r"\.(\d+)\."
if not (search_res := re.findall(pattern, layer_name)):
return layer_name, -1
layer_name_template = re.sub(pattern, ".{}.", layer_name, count=num_matches)
return layer_name_template, *(int(x) for x in search_res[:num_matches])
def load_param(
param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool = False
) -> torch.Tensor:
if hasattr(param, "_load_tensor"):
# support tensors loaded via `lazy_load()`
if verbose:
print(f"Loading {name!r} into RAM")
param = param._load_tensor()
if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:
if verbose:
print(f"Converting {name!r} from {param.dtype} to {dtype}")
param = param.to(dtype)
return param
@torch.inference_mode()
def convert_hf_checkpoint(
checkpoint_dir: Path,
*,
model_name: Optional[str] = None,
dtype: Optional[str] = None,
debug_mode: Optional[bool] = False,
) -> None:
"""
Convert a Hugging Face Transformers checkpoint into a LitGPT compatible checkpoint.
Arguments:
checkpoint_dir: Where to save the downloaded files.
model_name: The existing config name to load. This is useful to download alternative weights of existing
architectures.
dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the
dtype they are downloaded in.
debug_mode: Prints the individual layers being loaded instead of a progress bar, which can be useful when
developing and adding new models to LitGPT.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
if model_name is None:
model_name = checkpoint_dir.name
if dtype is not None:
dtype = getattr(torch, dtype)
config = Config.from_name(model_name)
save_config(config, checkpoint_dir)
if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, config)
elif model_name.lower().startswith("gemma-2"):
qkv_weights = {}
copy_fn = partial(copy_weights_gemma_2, qkv_weights)
elif model_name.lower().startswith("gemma-3"):
qkv_weights = {}
copy_fn = partial(copy_weights_gemma_3, qkv_weights, config=config)
elif model_name.lower().startswith("phi"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_phi, config, qkv_weights)
elif model_name.lower().startswith(("qwen2.5", "qwq")):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
elif model_name.lower().startswith("olmo-2-"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
elif model_name.lower().startswith("qwen3"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
else:
copy_fn = partial(copy_weights_gpt_neox, config)
# initialize a new empty state dict to hold our new weights
sd = {}
# Load the json file containing weight mapping
pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
model_safetensor_map_json_path = checkpoint_dir / "model.safetensors.index.json"
if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file
with open(pytorch_bin_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
elif model_safetensor_map_json_path.is_file():
with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
bin_files = set(checkpoint_dir.glob("*.bin")) | set(checkpoint_dir.glob("*.safetensors"))
# some checkpoints serialize the training arguments
bin_files = {f for f in bin_files if f.name != "training_args.bin"}
if not bin_files:
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files")
with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
# for checkpoints that split the QKV across several files, we need to keep all the bin files
# open, so we use `ExitStack` to close them all together at the end
if not debug_mode:
# Using tqdm progress bar when not in debug mode
total_size = max(1, sum(os.path.getsize(bin_file) for bin_file in bin_files))
total_progress = 100
with tqdm(
total=total_progress,
desc="Initializing",
bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}",
) as pbar:
for bin_file in sorted(bin_files):
pbar.set_description(f"Loading weights: {bin_file.name}")
current_file_size = os.path.getsize(bin_file)
progress_per_file = (current_file_size / total_size) * total_progress
hf_weights = (
load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
)
copy_fn(
sd,
hf_weights,
saver=saver,
dtype=dtype,
pbar=pbar,
progress_per_file=progress_per_file,
debug_mode=debug_mode,
)
gc.collect()
if pbar.n < total_progress:
pbar.update(total_progress - pbar.n)
pbar.close()
else:
# Handling files without progress bar in debug mode
for bin_file in sorted(bin_files):
hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode)
print(f"Saving converted checkpoint to {checkpoint_dir}")
saver.save(sd)
================================================
FILE: litgpt/scripts/convert_lit_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import gc
from collections import defaultdict
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Dict, Optional, Union
import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import layer_template, load_param
from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load
def copy_weights_falcon(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "transformer.word_embeddings.weight",
"transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight",
"transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight",
"transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight",
"transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight",
"transformer.ln_f.bias": "transformer.ln_f.bias",
"transformer.ln_f.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
# the original model definition is different for each size
if "7b" in config.name:
weight_map.update(
{
"transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias",
"transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight",
}
)
elif "40b" in config.name or "180B" in config.name:
weight_map.update(
{
"transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias",
"transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight",
"transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias",
"transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight",
}
)
else:
raise NotImplementedError
for from_name, param in lit_weights.items():
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template].format(layer_idx)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
# Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_gpt_neox(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "gpt_neox.embed_in.weight",
"transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias",
"transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight",
"transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias",
"transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight",
"transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias",
"transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight",
"transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias",
"transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias",
"transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight",
"transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias",
"transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight",
"transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias",
"transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight",
"lm_head.weight": "embed_out.weight",
}
for from_name, param in lit_weights.items():
name_template, layer_idx = layer_template(from_name)
to_name = weight_map[name_template].format(layer_idx)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
# Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]
param = qkv_reassemble(param, config)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_llama(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias",
"transformer.ln_f.weight": "model.norm.weight",
"transformer.ln_f.bias": "model.norm.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name == "LLaMAMoE":
weight_map.update(
{
"transformer.h.{}.mlp.gate.weight": "model.layers.{}.block_sparse_moe.gate.weight",
"transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.block_sparse_moe.experts.{}.w1.weight",
"transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.block_sparse_moe.experts.{}.w3.weight",
"transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.block_sparse_moe.experts.{}.w2.weight",
}
)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_gemma_2(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = True,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.pre_feedforward_layernorm.weight",
"transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"lm_head.weight": "lm_head.weight",
}
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_gemma_3(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = True,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.pre_feedforward_layernorm.weight",
"transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"lm_head.weight": "lm_head.weight",
"transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
"transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
}
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_phi(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.dense.weight",
"transformer.h.{}.attn.proj.bias": "model.layers.{}.self_attn.dense.bias",
"transformer.h.{}.mlp.fc.weight": "model.layers.{}.mlp.fc1.weight",
"transformer.h.{}.mlp.fc.bias": "model.layers.{}.mlp.fc1.bias",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.fc2.weight",
"transformer.h.{}.mlp.proj.bias": "model.layers.{}.mlp.fc2.bias",
"transformer.ln_f.weight": "model.final_layernorm.weight",
"transformer.ln_f.bias": "model.final_layernorm.bias",
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}
if config.name.lower().startswith(("phi-3", "phi-4")):
weight_map.update(
{
"transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.ln_f.weight": "model.norm.weight",
}
)
gate_up_proj_weights = defaultdict(dict)
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and config.name.startswith("Phi-4"):
continue
name_template, layer_idx = layer_template(from_name)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
if config.name.lower().startswith(("phi-3", "phi-4")):
to_names = (weight_map[name_template].format(layer_idx),)
params = (param,)
else:
weight_type = from_name.split(".")[-1] # weight or bias
to_names = (
f"model.layers.{{}}.self_attn.q_proj.{weight_type}".format(layer_idx),
f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx),
f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
elif from_name.endswith((".fc_1.weight", ".fc_2.weight")):
weight = load_param(param, from_name, None)
weight_name = from_name.split(".")[-2]
gate_up_proj_weights[layer_idx][weight_name] = weight
else:
to_names = (weight_map[name_template].format(layer_idx),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if config.name.lower().startswith(("phi-3", "phi-4")):
for layer_idx in list(gate_up_proj_weights):
fc_1_weight = gate_up_proj_weights[layer_idx]["fc_1"]
fc_2_weight = gate_up_proj_weights[layer_idx]["fc_2"]
weight = torch.concat([fc_1_weight, fc_2_weight], dim=0)
layer_name = f"model.layers.{layer_idx}.mlp.gate_up_proj.weight"
state_dict[layer_name] = weight
del gate_up_proj_weights[layer_idx]
def copy_weights_qwen_2_5(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.ln_f.weight": "model.norm.weight",
"lm_head.weight": "lm_head.weight",
}
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
weight_type = from_name.split(".")[-1] # weight or bias
to_names = (
"model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_olmo2(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
"transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias",
"transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"transformer.ln_f.bias": "model.norm.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def copy_weights_qwen_3(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
"transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name == "LLaMAMoE":
weight_map.update(
{
"transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight",
"transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight",
"transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight",
"transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.mlp.experts.{}.down_proj.weight",
}
)
elif config.mlp_class_name == "LLaMAMLP":
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError
for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
weight_type = from_name.split(".")[-1] # weight or bias
to_names = (
"model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
"model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)
for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
"""
q, k, v = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
qs = q.split(config.n_head // config.n_query_groups * config.head_size)
ks = k.split(config.head_size)
vs = v.split(config.head_size)
interleaved = [t for group in zip(qs, ks, vs) for t in group]
return torch.cat(interleaved)
def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:
if any("lora" in wn for wn in lit_weights):
raise ValueError("Checkpoints with LoRA weights cannot be converted. Call `scripts/merge_lora.py` first.")
if any("adapter" in wn or "gating_factor" in wn for wn in lit_weights):
raise NotImplementedError("Converting adapter models is not supported.")
@torch.inference_mode()
def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
"""Convert a LitGPT trained checkpoint into a Hugging Face Transformers checkpoint."""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
config = Config.from_file(checkpoint_dir / "model_config.yaml")
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "model.pth"
if "falcon" in config.name:
copy_fn = partial(copy_weights_falcon, config)
elif config.name.startswith("Gemma-2"):
copy_fn = partial(copy_weights_gemma_2, config)
elif config.name.startswith("Gemma-3"):
copy_fn = partial(copy_weights_gemma_3, config)
elif config.name.lower().startswith("phi"):
copy_fn = partial(copy_weights_phi, config)
elif config.name.lower().startswith(("qwen2.5", "qwq")):
copy_fn = partial(copy_weights_qwen_2_5, config)
elif config.name.lower().startswith("olmo-2-"):
copy_fn = partial(copy_weights_olmo2, config)
elif config.name.lower().startswith("qwen3"):
copy_fn = partial(copy_weights_qwen_3, config)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
untie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
else:
copy_fn = partial(copy_weights_gpt_neox, config)
# initialize a new empty state dict to hold our new weights
sd = {}
with incremental_save(output_path) as saver:
lit_weights = lazy_load(checkpoint_dir / "lit_model.pth")
lit_weights = lit_weights.get("model", lit_weights)
check_conversion_supported(lit_weights)
copy_fn(sd, lit_weights, saver=saver)
gc.collect()
saver.save(sd)
================================================
FILE: litgpt/scripts/convert_pretrained_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from pathlib import Path
from pprint import pprint
import torch
from litgpt.utils import copy_config_files, extend_checkpoint_dir, incremental_save
@torch.inference_mode()
def convert_pretrained_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
"""Convert a checkpoint after pretraining.
The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training
is finished. This script will export the state-dict of the model and place it in the chosen output folder,
which then can be loaded by other scripts for inference, evaluation, etc.
Args:
checkpoint_dir: Path to a checkpoint directory produced by ``litgpt.pretrain``.
output_dir: The output folder where the converted state-dict file and config files will be saved to.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())
if output_dir.is_dir() and output_dir.glob("*"):
raise FileExistsError(
f"The output folder exists and is not empty: {str(output_dir)}."
" Please delete it first or choose a different name."
)
output_dir.mkdir(parents=True)
checkpoint_file = checkpoint_dir / "lit_model.pth"
output_checkpoint_file = output_dir / "lit_model.pth"
# TODO: Consolidate sharded checkpoint if applicable
# Extract the model state dict and save to output folder
with incremental_save(output_checkpoint_file) as saver:
print("Processing", checkpoint_file)
full_checkpoint = torch.load(str(checkpoint_file), mmap=True)
loaded_state_dict = full_checkpoint["model"]
converted_state_dict = {}
for param_name, param in loaded_state_dict.items():
saver.store_early(param)
# remove prefix for compiled model (if any)
param_name = param_name.replace("_orig_mod.", "")
converted_state_dict[param_name] = param
print(f"Saving converted checkpoint to {str(output_checkpoint_file)}.")
saver.save(converted_state_dict)
copy_config_files(checkpoint_dir, output_dir)
================================================
FILE: litgpt/scripts/download.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import importlib.util
import os
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional, Tuple
from litgpt.config import configs
from litgpt.constants import _HF_TRANSFER_AVAILABLE, _SAFETENSORS_AVAILABLE
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
def download_from_hub(
repo_id: str,
access_token: Optional[str] = os.getenv("HF_TOKEN"),
tokenizer_only: bool = False,
convert_checkpoint: bool = True,
dtype: Optional[str] = None,
checkpoint_dir: Path = Path("checkpoints"),
model_name: Optional[str] = None,
) -> None:
"""Download weights or tokenizer data from the Hugging Face Hub.
Arguments:
repo_id: The repository ID in the format ``org/name`` or ``user/name`` as shown in Hugging Face.
If "list" is provided as input, a list of the currently supported models in LitGPT and quits.
access_token: Optional API token to access models with restrictions.
tokenizer_only: Whether to download only the tokenizer files.
convert_checkpoint: Whether to convert the checkpoint files to the LitGPT format after downloading.
dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the
dtype they are downloaded in.
checkpoint_dir: Where to save the downloaded files.
model_name: The existing config name to use for this repo_id. This is useful to download alternative weights of
existing architectures.
"""
options = [f"{config['hf_config']['org']}/{config['hf_config']['name']}" for config in configs]
if repo_id == "list":
print("Please specify --repo_id . Available values:")
print("\n".join(sorted(options, key=lambda x: x.lower())))
return
if model_name is None and repo_id not in options:
print(
f"Unsupported `repo_id`: {repo_id}."
"\nIf you are trying to download alternative "
"weights for a supported model, please specify the corresponding model via the `--model_name` option, "
"for example, `litgpt download NousResearch/Hermes-2-Pro-Llama-3-8B --model_name Llama-3-8B`."
"\nAlternatively, please choose a valid `repo_id` from the list of supported models, which can be obtained via "
"`litgpt download list`."
)
return
from huggingface_hub import snapshot_download
if importlib.util.find_spec("hf_transfer") is None:
print(
"It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`"
)
download_files = ["tokenizer*", "generation_config.json", "config.json"]
if not tokenizer_only:
bins, safetensors = find_weight_files(repo_id, access_token)
if bins:
# covers `.bin` files and `.bin.index.json`
download_files.append("*.bin*")
elif safetensors:
if not _SAFETENSORS_AVAILABLE:
raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
download_files.append("*.safetensors*")
else:
raise ValueError(f"Couldn't find weight files for {repo_id}")
import huggingface_hub._snapshot_download as download
import huggingface_hub.constants as constants
previous = constants.HF_HUB_ENABLE_HF_TRANSFER
if _HF_TRANSFER_AVAILABLE and not previous:
print("Setting HF_HUB_ENABLE_HF_TRANSFER=1")
constants.HF_HUB_ENABLE_HF_TRANSFER = True
download.HF_HUB_ENABLE_HF_TRANSFER = True
directory = checkpoint_dir / repo_id
with gated_repo_catcher(repo_id, access_token):
snapshot_download(
repo_id,
local_dir=directory,
allow_patterns=download_files,
token=access_token,
)
constants.HF_HUB_ENABLE_HF_TRANSFER = previous
download.HF_HUB_ENABLE_HF_TRANSFER = previous
if convert_checkpoint and not tokenizer_only:
print("Converting checkpoint files to LitGPT format.")
convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name)
def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]:
from huggingface_hub import repo_info
from huggingface_hub.utils import filter_repo_objects
with gated_repo_catcher(repo_id, access_token):
info = repo_info(repo_id, token=access_token)
filenames = [f.rfilename for f in info.siblings]
bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
return bins, safetensors
@contextmanager
def gated_repo_catcher(repo_id: str, access_token: Optional[str]):
try:
yield
except OSError as e:
err_msg = str(e)
if "Repository Not Found" in err_msg:
raise ValueError(
f"Repository at https://huggingface.co/api/models/{repo_id} not found."
" Please make sure you specified the correct `repo_id`."
) from None
elif "gated repo" in err_msg:
if not access_token:
raise ValueError(
f"https://huggingface.co/{repo_id} requires authentication, please set the `HF_TOKEN=your_token`"
" environment variable or pass `--access_token=your_token`. You can find your token by visiting"
" https://huggingface.co/settings/tokens."
) from None
else:
raise ValueError(
f"https://huggingface.co/{repo_id} requires authentication. The access token provided by `HF_TOKEN=your_token`"
" environment variable or `--access_token=your_token` may not have sufficient access rights. Please"
f" visit https://huggingface.co/{repo_id} for more information."
) from None
raise e from None
================================================
FILE: litgpt/scripts/merge_lora.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""This script merges the LoRA weights with the base model"""
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, Optional, Tuple
import lightning as L
import torch
import yaml
from litgpt.lora import GPT, Config, lora_filter, merge_lora_weights
from litgpt.utils import check_valid_checkpoint_dir, extend_checkpoint_dir
def merge_lora(
checkpoint_dir: Path, pretrained_checkpoint_dir: Optional[Path] = None, precision: Optional[str] = None
) -> None:
"""Merges the LoRA weights with the base model.
See ``litgpt finetune lora``.
Creates a new ``lit_model.pth`` file by merging the LoRA weights (``lit_model.pth.lora``)
with the original checkpoint weights.
Arguments:
checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of
``litgpt finetune lora``.
pretrained_checkpoint_dir: Optional path to the checkpoint directory with the weights of the base model
corresponding to the LoRA checkpoint. By default, this will automatically be inferred from the metadata
in the given `checkpoint_dir` directory. Only set this if the base model's checkpoint directory
has moved or was renamed.
precision: Optional precision setting to instantiate the model weights in. By default, this will
automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
if pretrained_checkpoint_dir is not None:
pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir)
pprint(locals())
check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth.lora")
if pretrained_checkpoint_dir is not None:
check_valid_checkpoint_dir(pretrained_checkpoint_dir)
if (checkpoint_dir / "lit_model.pth").is_file():
print("LoRA weights have already been merged in this checkpoint.")
return
lora_params, meta_pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)
precision = precision if precision is not None else lora_precision
if pretrained_checkpoint_dir is None:
pretrained_checkpoint_dir = meta_pretrained_checkpoint_dir
pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir)
fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu")
config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params)
with fabric.init_module(), torch.device("meta"):
model = GPT(config)
# we don't care about these to perform merging
model.cos = None
model.sin = None
lora_path = checkpoint_dir / "lit_model.pth.lora"
pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True)
lora_checkpoint = torch.load(str(lora_path), mmap=True)
lora_checkpoint = lora_checkpoint.get("model", lora_checkpoint)
# Merge LoRA weights into the base model
pretrained_checkpoint.update(lora_checkpoint)
model.load_state_dict(pretrained_checkpoint, assign=True)
# since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected dtype
lora_dtype = next(iter(lora_checkpoint.values())).dtype
model.to(dtype=lora_dtype, device="cpu")
merge_lora_weights(model)
# Remove LoRA parameters and the LoRA linear substring
state_dict = {k.replace("linear.", ""): v for k, v in model.state_dict().items() if not lora_filter(k, v)}
save_path = checkpoint_dir / "lit_model.pth"
torch.save(state_dict, save_path)
fabric.print(f"Saved merged weights to {str(checkpoint_dir / 'lit_model.pth')!r}")
def load_lora_metadata(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path, Optional[str]]:
hparams_file = checkpoint_dir / "hyperparameters.yaml"
if not hparams_file.is_file():
raise FileNotFoundError(
f"The path {str(hparams_file)!r} is not a valid checkpoint directory. It is missing a"
f" `hyperparameters.yaml` file. Please point to the checkpoint directory that was produced by"
f" the `litgpt/finetune/lora.py` script."
)
with open(hparams_file, encoding="utf-8") as file:
hparams = yaml.safe_load(file)
lora_params = {k: v for k, v in hparams.items() if k.startswith("lora_")}
pretrained_checkpoint_dir = Path(hparams["checkpoint_dir"])
precision = hparams.get("precision")
return lora_params, pretrained_checkpoint_dir, precision
================================================
FILE: litgpt/tokenizer.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from pathlib import Path
from typing import Iterable, Iterator, Optional, Union
import torch
from litgpt.utils import fix_and_load_json
class Tokenizer:
def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
checkpoint_dir = Path(checkpoint_dir)
if not checkpoint_dir.exists():
raise NotADirectoryError(f"The checkpoint directory does not exist: {str(checkpoint_dir)}")
self.model_name = checkpoint_dir.stem
self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
self.bos_id = None
self.eos_id = None
# some checkpoints have both files, `.json` takes precedence
if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
from tokenizers import Tokenizer as HFTokenizer
self.processor = HFTokenizer.from_file(str(vocabulary_path))
self.backend = "huggingface"
if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
with open(special_tokens_path, encoding="utf-8") as fp:
config = json.load(fp)
bos_token = config.get("bos_token")
eos_token = config.get("eos_token")
if bos_token is not None and isinstance(bos_token, dict):
bos_token = bos_token.get("content")
if eos_token is not None and isinstance(eos_token, dict):
eos_token = eos_token.get("content")
self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
try:
with open(special_tokens_path, encoding="utf-8") as fp:
config = json.load(fp)
except json.JSONDecodeError: # Some files like the Llama 3.2 one have bugs
with open(special_tokens_path, encoding="utf-8") as fp:
json_string = fp.read()
config = fix_and_load_json(json_string)
if self.bos_id is None:
self.bos_id = config.get("bos_token_id")
if self.eos_id is None:
self.eos_id = config.get("eos_token_id")
elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
from sentencepiece import SentencePieceProcessor
self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
self.backend = "sentencepiece"
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
else:
raise NotImplementedError
# NOTE: A temporary fix until it's resolved on Tokenizers side.
# LlaMA tokenizer strips leading spaces if to decode a single token at a time.
# https://github.com/huggingface/transformers/issues/31643
self.apply_decoding_fix = None
if (config_path := checkpoint_dir / "tokenizer_config.json").is_file():
with open(config_path, encoding="utf-8") as fp:
self.apply_decoding_fix = "LlamaTokenizer" in json.load(fp)["tokenizer_class"]
@property
def vocab_size(self) -> int:
if self.backend == "huggingface":
return self.processor.get_vocab_size(with_added_tokens=False)
if self.backend == "sentencepiece":
return self.processor.vocab_size()
raise RuntimeError
def token_to_id(self, token: str) -> int:
if self.backend == "huggingface":
id_ = self.processor.token_to_id(token)
elif self.backend == "sentencepiece":
id_ = self.processor.piece_to_id(token)
else:
raise RuntimeError
if id_ is None:
raise ValueError(f"token {token!r} not found in the collection.")
return id_
def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
return False
with open(tokenizer_config_path, encoding="utf-8") as fp:
config = json.load(fp)
# for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only
# `PreTrainedTokenizerFast`
if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")):
return True
if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"):
return True
if "add_bos_token" in config:
return config["add_bos_token"]
# if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
# ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
return config.get("tokenizer_class") == "LlamaTokenizer"
def encode(
self,
string: str,
device: Optional[torch.device] = None,
bos: Optional[bool] = None,
eos: bool = False,
max_length: int = -1,
) -> torch.Tensor:
if self.backend == "huggingface":
tokens = self.processor.encode(string).ids
elif self.backend == "sentencepiece":
tokens = self.processor.encode(string)
else:
raise RuntimeError(f"`{self.backend}` is not supported.")
if tokens is None:
raise ValueError("`self.processor` returned tokens of None value.")
if bos or (bos is None and self.use_bos):
if self.bos_id is None:
raise NotImplementedError("This tokenizer does not have a defined bos token.")
if not tokens or tokens[0] != self.bos_id:
tokens = [self.bos_id] + tokens
# if the processor misbehaves and adds `bos` token no matter what
elif tokens and tokens[0] == self.bos_id:
tokens = tokens[1:]
if eos and (not tokens or tokens[-1] != self.eos_id):
tokens = tokens + [self.eos_id]
# if the processor misbehaves and adds `eos` token no matter what
elif tokens and tokens[-1] == self.eos_id:
tokens = tokens[:-1]
if max_length > 0:
tokens = tokens[:max_length]
return torch.tensor(tokens, dtype=torch.int, device=device)
def decode(self, tensor: torch.Tensor) -> str:
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
if len(tokens) == 1 and self.apply_decoding_fix:
dummy_token_id = 33 # \x1e
dummy_token = self.processor.decode([dummy_token_id])
if dummy_token != "\x1e":
dummy_token_id = 165 # \x1e is different in salamandra tokenizers
dummy_token = self.processor.decode([dummy_token_id])
return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :]
return self.processor.decode(tokens)
def decode_stream(
self, token_stream: Iterable[torch.Tensor], device: Optional[torch.device] = None
) -> Iterator[str]:
if self.backend == "huggingface":
try:
for token in token_stream:
yield self.decode(token)
except KeyboardInterrupt:
return
elif self.backend == "sentencepiece":
# TODO: Is there a way to not have to do this?
# This may actually affect our tokens per second.
# sentencepiece does not support decoding token-by-token because it adds spaces based on the surrounding tokens
# meaning that we need to decode everything each time
so_far = torch.tensor([], dtype=torch.long, device=device)
decoded_so_far = ""
try:
for token in token_stream:
so_far = so_far.to(device=token.device)
so_far = torch.cat((so_far, token.view(-1)))
decoded_new = self.decode(so_far)
yield decoded_new[len(decoded_so_far) :]
decoded_so_far = decoded_new
except KeyboardInterrupt:
return
else:
raise NotImplementedError(self.backend)
================================================
FILE: litgpt/types.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Type aliases used across LitGPT modules."""
from typing import Literal
# Logger-related types
LoggerChoice = Literal["csv", "tensorboard", "wandb", "mlflow", "litlogger"]
"""Valid logger choices for experiment tracking.
Available options:
- "csv": Local CSV file logging (default for most scripts)
- "tensorboard": TensorBoard visualization (default for pretrain)
- "wandb": Weights & Biases cloud tracking
- "mlflow": MLflow experiment tracking
- "litlogger": Lightning.ai native tracking
"""
================================================
FILE: litgpt/utils.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Utility functions for training and inference."""
import inspect
import json
import math
import os
import pickle
import random
import re
import shutil
import subprocess
import sys
import warnings
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
import lightning as L
import psutil
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.cli import instantiate_class
from lightning.pytorch.loggers import MLFlowLogger, WandbLogger
from packaging import version
from torch.serialization import normalize_storage_type
from typing_extensions import Self
from litgpt.constants import (
_LITLOGGER_AVAILABLE,
_SUPPORTED_LOGGERS,
_THUNDER_AVAILABLE,
)
from litgpt.types import LoggerChoice
if TYPE_CHECKING:
from litgpt import GPT, Config
def init_out_dir(out_dir: Path) -> Path:
if not isinstance(out_dir, Path):
out_dir = Path(out_dir)
if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
return out_dir
def find_resume_path(resume: Union[bool, Literal["auto"], Path], out_dir: Path) -> Optional[Path]:
if not resume or isinstance(resume, Path):
return resume
resume_path = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1])), default=None)
if resume == "auto":
return resume_path
if resume is True and resume_path is None:
raise FileNotFoundError(
f"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`."
)
return resume_path
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
total = 0
for p in module.parameters():
if requires_grad is None or p.requires_grad == requires_grad:
if hasattr(p, "quant_state"):
# bitsandbytes 4bit layer support
total += math.prod(p.quant_state.shape)
else:
total += p.numel()
return total
def reset_parameters(module: nn.Module) -> None:
"""Calls `reset_parameters` on the module and all its submodules."""
for mod in module.modules():
if callable(getattr(mod, "reset_parameters", None)):
mod.reset_parameters()
def check_valid_checkpoint_dir(
checkpoint_dir: Path,
model_filename: str = "lit_model.pth",
verbose: bool = True,
raise_error: bool = False,
ignore_tokenizer_files: bool = False,
) -> None:
files = {
model_filename: (checkpoint_dir / model_filename).is_file(),
"model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
}
if not ignore_tokenizer_files:
files.update(
{
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
or (checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
}
)
if checkpoint_dir.is_dir():
if all(files.values()):
# we're good
return
problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
else:
problem = " is not a checkpoint directory"
# list locally available checkpoints
available = list(Path("checkpoints").glob("*/*"))
if available:
options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
extra = f"\nYou have downloaded locally:{options}\n"
else:
extra = ""
if verbose:
error_message = (
f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
"\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
f"{extra}\nSee all download options by running:\n litgpt download"
)
print(error_message, file=sys.stderr)
if raise_error:
raise FileNotFoundError(f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}.")
else:
raise SystemExit(1)
class SavingProxyForStorage:
def __init__(self, obj, saver, protocol_version=5):
self.protocol_version = protocol_version
self.saver = saver
if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
raise TypeError(f"expected storage, not {type(obj)}")
# this logic is taken from PyTorch 2.0+ torch/serialization.py
if isinstance(obj, torch.storage.TypedStorage):
# PT upstream wants to deprecate this eventually...
storage = obj._untyped_storage
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj._size()
else:
storage = obj
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
storage_key = saver._write_storage_and_return_key(storage)
location = torch.serialization.location_tag(storage)
self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
def __reduce_ex__(self, protocol_version):
assert False, "this should be handled with out of band"
class SavingProxyForTensor:
def __init__(self, tensor, saver, protocol_version=5):
self.protocol_version = protocol_version
self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
if reduce_args[0] == torch._utils._rebuild_tensor_v2:
# for Tensors with Python attributes
(a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), (
"Please check for updates"
)
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
else:
(storage, *other_reduce_args) = reduce_args
assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), (
"Please check for updates"
)
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
self.reduce_args = (storage_proxy, *other_reduce_args)
def __reduce_ex__(self, protocol_version):
if protocol_version != self.protocol_version:
raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
return self.reduce_ret_fn, self.reduce_args
class IncrementalPyTorchPickler(pickle.Pickler):
def __init__(self, saver, *args, **kwargs):
super().__init__(*args, **kwargs)
self.storage_dtypes = {}
self.saver = saver
self.id_map = {}
# this logic is taken from PyTorch 2.0+ torch/serialization.py
def persistent_id(self, obj):
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
if isinstance(obj, SavingProxyForStorage):
return obj.storage_info
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj._size()
else:
storage = obj
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
# not allocated, don't perform this check
if storage.data_ptr() != 0:
if storage.data_ptr() in self.storage_dtypes:
if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
"Cannot save multiple tensors or storages that view the same data as different types"
)
else:
self.storage_dtypes[storage.data_ptr()] = storage_dtype
storage_key = self.id_map.get(storage._cdata)
if storage_key is None:
storage_key = self.saver._write_storage_and_return_key(storage)
self.id_map[storage._cdata] = storage_key
location = torch.serialization.location_tag(storage)
return ("storage", storage_type, storage_key, location, storage_numel)
return None
class incremental_save:
def __init__(self, name):
self.name = name
self.zipfile = torch._C.PyTorchFileWriter(str(name))
self.has_saved = False
self.next_key = 0
self.protocol_version = 2
def __enter__(self):
return self
def store_early(self, tensor):
if isinstance(tensor, torch.Tensor):
return SavingProxyForTensor(tensor, self, protocol_version=self.protocol_version)
raise TypeError(f"can only store tensors early, not {type(tensor)}")
def save(self, obj):
if self.has_saved:
raise RuntimeError("have already saved")
# Write the pickle data for `obj`
data_buf = BytesIO()
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=self.protocol_version)
pickler.dump(obj)
data_value = data_buf.getvalue()
self.zipfile.write_record("data.pkl", data_value, len(data_value))
self.has_saved = True
def _write_storage_and_return_key(self, storage):
if self.has_saved:
raise RuntimeError("have already saved")
key = self.next_key
self.next_key += 1
name = f"data/{key}"
if storage.device.type != "cpu":
storage = storage.cpu()
num_bytes = storage.nbytes()
current_version = version.parse(torch.__version__)
threshold_version = version.parse("2.2.2")
if current_version <= threshold_version:
self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
else:
self.zipfile.write_record(name, storage, num_bytes)
return key
def __exit__(self, type, value, traceback):
self.zipfile.write_end_of_file()
T = TypeVar("T")
def chunked_cross_entropy(
logits: Union[torch.Tensor, List[torch.Tensor]],
targets: torch.Tensor,
chunk_size: int = 128,
ignore_index: int = -100,
) -> torch.Tensor:
# with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
# the memory usage in fine-tuning settings with low number of parameters.
# as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
# the memory spike's magnitude
# lm_head was chunked (we are fine-tuning)
if isinstance(logits, list):
# don't want to chunk cross entropy
if chunk_size == 0:
logits = torch.cat(logits, dim=1)
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)
# chunk cross entropy
logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
loss_chunks = [
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
]
non_masked_elems = (targets != ignore_index).sum()
# See [non_masked_elems div note]
return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))
# no chunking at all
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
if chunk_size == 0:
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)
# lm_head wasn't chunked, chunk cross entropy
logit_chunks = logits.split(chunk_size)
target_chunks = targets.split(chunk_size)
loss_chunks = [
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
]
non_masked_elems = (targets != ignore_index).sum()
# [non_masked_elems div note]:
# max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
# results in a python int which is then passed back to torch division. By using the
# `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
for checkpoint_name, attribute_name in mapping.items():
full_checkpoint_name = prefix + checkpoint_name
if full_checkpoint_name in state_dict:
full_attribute_name = prefix + attribute_name
state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
return state_dict
def get_default_supported_precision(training: bool) -> str:
"""
Return the default precision that is supported by the hardware: either `bf16` or `16`.
Args:
training: If True, returns '-mixed' version of the precision; if False, returns '-true' version.
Returns:
The default precision that is suitable for the task and is supported by the hardware.
"""
import torch
if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return "bf16-mixed" if training else "bf16-true"
else:
return "16-mixed" if training else "16-true"
return "bf16-mixed" if training else "bf16-true"
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
if isinstance(fabric.strategy, FSDPStrategy):
fabric.load_raw(checkpoint_path, model, strict=strict)
elif isinstance(fabric.strategy, ModelParallelStrategy):
state_dict = torch.load(checkpoint_path, mmap=True)
load_from_full_model_state_dict(
model=model,
full_sd=state_dict,
device=fabric.device,
strict=strict,
cpu_offload=True,
)
else:
state_dict = lazy_load(checkpoint_path)
state_dict = state_dict.get("model", state_dict)
model.load_state_dict(state_dict, strict=strict)
def load_checkpoint_update(
fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True
) -> None:
if isinstance(fabric.strategy, FSDPStrategy):
fabric.load_raw(checkpoint_path, model, strict=strict)
else:
state_dict = lazy_load(checkpoint_path)
state_dict = state_dict.get("model", state_dict)
adapter_cp = lazy_load(adapter_path)
state_dict.update(adapter_cp)
model.load_state_dict(state_dict, strict=strict)
def load_from_full_model_state_dict(
model: torch.nn.Module,
full_sd: Dict[str, Any],
device: torch.device,
strict: bool = False,
cpu_offload: bool = False,
):
from torch.distributed._tensor import distribute_tensor
meta_sharded_sd = model.state_dict()
sharded_sd = {}
print(meta_sharded_sd.keys())
for param_name, full_tensor in full_sd.items():
if "norm" not in param_name and "wte" not in param_name and "ln_f" not in param_name:
param_name = param_name.replace(".weight", ".linear.weight")
param_name = param_name.replace(".bias", ".linear.bias")
else:
param_name = param_name
print(param_name)
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
# this assumes that all samples have a fixed length equal to the block size
# which is most likely false during finetuning
flops_per_seq = flops_per_token * max_seq_length
attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
return flops_per_seq + attn_flops_per_seq
def estimate_flops(model: "GPT", training: bool) -> int:
"""Measures estimated FLOPs for MFU.
Refs:
* https://ar5iv.labs.arxiv.org/html/2205.05198#A1
* https://ar5iv.labs.arxiv.org/html/2204.02311#A2
"""
# using all parameters for this is a naive over estimation because not all model parameters actually contribute to
# this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
# (~10%) compared to the measured FLOPs, making those lower but more realistic.
# For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
n_trainable_params = num_parameters(model, requires_grad=True)
trainable_flops = flops_per_param(
model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
)
# forward + backward + gradients (assumes no gradient accumulation)
ops_per_step = 3 if training else 1
n_frozen_params = num_parameters(model, requires_grad=False)
frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
# forward + backward
frozen_ops_per_step = 2 if training else 1
return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
class CycleIterator:
"""An iterator that cycles through an iterable indefinitely.
Example:
>>> iterator = CycleIterator([1, 2, 3])
>>> [next(iterator) for _ in range(5)]
[1, 2, 3, 1, 2]
Note:
Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
"""
def __init__(self, iterable: Iterable) -> None:
self.iterable = iterable
self.epoch = 0
self._iterator = None
def __next__(self) -> Any:
if self._iterator is None:
self._iterator = iter(self.iterable)
try:
return next(self._iterator)
except StopIteration:
self._iterator = iter(self.iterable)
self.epoch += 1
return next(self._iterator)
def __iter__(self) -> Self:
return self
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
"""Copies the specified configuration and tokenizer files into the output directory."""
config_files = ["config.json", "generation_config.json", "model_config.yaml"]
tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
for file_name in config_files + tokenizer_files:
src_path = source_dir / file_name
if src_path.exists():
shutil.copy(src_path, out_dir)
def CLI(*args: Any, **kwargs: Any) -> Any:
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
return CLI(*args, **kwargs)
def capture_hparams() -> Dict[str, Any]:
"""Captures the local variables ('hyperparameters') from where this function gets called."""
caller_frame = inspect.currentframe().f_back
locals_of_caller = caller_frame.f_locals
hparams = {}
for name, value in locals_of_caller.items():
if value is None or isinstance(value, (int, float, str, bool, Path)):
hparams[name] = value
elif is_dataclass(value):
hparams[name] = asdict(value)
else:
hparams[name] = str(value)
return hparams
def save_config(config: "Config", checkpoint_dir: Path) -> None:
config_dict = asdict(config)
with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
yaml.dump(config_dict, fp)
def parse_devices(devices: Union[str, int]) -> int:
if devices in (-1, "auto"):
return torch.cuda.device_count() or 1
if isinstance(devices, int) and devices > 0:
return devices
raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
def choose_logger(
logger_name: LoggerChoice,
out_dir: Path,
name: str,
log_interval: int = 1,
log_args: Optional[Dict] = None,
resume: Optional[bool] = None,
**kwargs: Any,
):
if logger_name == "csv":
return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs)
if logger_name == "tensorboard":
return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
if logger_name == "wandb":
project = log_args.pop("project", name)
run = log_args.pop("run", os.environ.get("WANDB_RUN_NAME"))
group = log_args.pop("group", os.environ.get("WANDB_RUN_GROUP"))
return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs)
if logger_name == "mlflow":
return MLFlowLogger(experiment_name=name, **kwargs)
if logger_name == "litlogger":
if not _LITLOGGER_AVAILABLE:
raise ModuleNotFoundError(_LITLOGGER_AVAILABLE)
from lightning.pytorch.loggers import LitLogger
# Extract litlogger-specific args
teamspace = log_args.pop("teamspace", None) if log_args else None
metadata = log_args.pop("metadata", None) if log_args else None
log_model = log_args.pop("log_model", False) if log_args else False
save_logs = log_args.pop("save_logs", True) if log_args else True
checkpoint_name = log_args.pop("checkpoint_name", None) if log_args else None
return LitLogger(
root_dir=(out_dir / "logs"),
name=name,
teamspace=teamspace,
metadata=metadata,
log_model=log_model,
save_logs=save_logs,
checkpoint_name=checkpoint_name,
**kwargs,
)
raise ValueError(
f"`--logger_name={logger_name}` is not a valid option. Choose from {', '.join(_SUPPORTED_LOGGERS)}."
)
def get_argument_names(cls):
sig = inspect.signature(cls.__init__)
return {
name
for name, param in sig.parameters.items()
if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
}
def instantiate_bnb_optimizer(optimizer, model_parameters):
if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
):
raise ValueError("The chosen quantization format only supports the AdamW optimizer.")
import bitsandbytes as bnb
if isinstance(optimizer, str):
optimizer = bnb.optim.PagedAdamW(model_parameters)
else:
optim_args = get_argument_names(bnb.optim.PagedAdamW)
allowed_kwargs = {key: optimizer["init_args"][key] for key in optim_args & optimizer["init_args"].keys()}
optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
return optimizer
def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
# Special care taken where some optimizers do not have some parameters referenced in some of the code, for example "fused" in the pretrain.py script:
# bnb.optim.AdamW8bit
# grokadamw.GrokAdamW
# torch.optim.RMSprop
if isinstance(optimizer, str):
if "." in optimizer:
class_module, class_name = optimizer.rsplit(".", 1)
else:
class_module, class_name = "torch.optim", optimizer
module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)
valid_params = set(inspect.signature(optimizer_cls).parameters)
kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
optimizer = optimizer_cls(model_parameters, **kwargs)
elif isinstance(optimizer, dict):
optimizer = dict(optimizer)
class_module, class_name = optimizer["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)
valid_params = set(inspect.signature(optimizer_cls).parameters)
kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}
optimizer["init_args"].update(kwargs)
optimizer = instantiate_class(model_parameters, optimizer)
else:
raise ValueError(f'Unrecognized "optimizer" value: {optimizer}')
return optimizer
def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
new_checkpoint_dir = "checkpoints" / checkpoint_dir
should_return_new_dir = (
not checkpoint_dir.is_dir()
and checkpoint_dir.parts[0] != "checkpoints"
and not checkpoint_dir.is_absolute()
and new_checkpoint_dir.exists()
)
return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660):
"""
Checks the file size and raises a warning if it exceeds the size_limit.
The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660
"""
size = 0.0
if os.path.exists(checkpoint_path):
size = os.path.getsize(checkpoint_path)
if size > size_limit and str(device) == "cpu":
warnings.warn(
f"The file size of {checkpoint_path} is over {size_limit / 1024 / 1024 / 1024:.1f} GB. Using a model "
"with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU."
)
return size
def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False):
from litgpt.scripts.download import download_from_hub # moved here due to circular import issue
checkpoint_dir = extend_checkpoint_dir(Path(model_name))
try:
check_valid_checkpoint_dir(
checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files
)
except FileNotFoundError as e:
if access_token is None:
access_token = os.getenv("HF_TOKEN")
if checkpoint_dir.parts[0] != "checkpoints" and not checkpoint_dir.is_absolute():
download_from_hub(repo_id=str(model_name), access_token=access_token)
checkpoint_dir = Path("checkpoints") / checkpoint_dir
else:
raise e
return checkpoint_dir
def check_nvlink_connectivity(fabric=None):
"""Checks GPU connectivity for both NVIDIA and AMD GPUs.
This function delegates to vendor-specific implementations based on
the detected GPU vendor.
"""
if fabric is not None:
custom_print = fabric.print
else:
custom_print = print
if os.getenv("RANK", "0") == "0":
try:
if torch.cuda.is_available():
device_properties = torch.cuda.get_device_properties(0)
gpu_name = device_properties.name.lower()
if "nvidia" in gpu_name:
_check_nvidia_connectivity(custom_print)
elif "advanced micro devices" in gpu_name or "amd" in gpu_name:
_check_amd_connectivity(custom_print)
else:
custom_print(f"Unrecognized GPU vendor: {device_properties.name}")
else:
custom_print("No GPUs available")
except Exception as e:
custom_print(f"An error occurred while checking GPU connectivity: {e}")
def _check_nvidia_connectivity(custom_print):
"""Checks NVLink connectivity on NVIDIA GPUs."""
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return
lines = result.stdout.strip().split("\n")
start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None)
if start_index is None:
custom_print("Failed to parse nvidia-smi output")
return
headers_line = lines[start_index]
headers = headers_line.split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])
all_nvlink = True
for line in lines[start_index + 1 : start_index + 1 + gpu_count]:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break
if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
def _check_amd_connectivity(custom_print):
"""Checks XGMI connectivity on AMD GPUs."""
result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run rocm-smi")
return
lines = result.stdout.strip().split("\n")
gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
if gpu_header_index is None or gpu_header_index == 0:
custom_print("Failed to parse rocm-smi output (no GPU headers found)")
return
header_line = lines[gpu_header_index - 1]
headers = header_line.strip().split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])
gpu_lines = []
for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
gpu_lines.append(line.strip())
if len(gpu_lines) != gpu_count:
custom_print("Mismatch in GPU count when parsing rocm-smi output")
return
all_xgmi = True
for line in gpu_lines:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
for conn in connections:
if conn not in ("XGMI", "0"):
all_xgmi = False
break
if not all_xgmi:
break
if all_xgmi:
custom_print("All GPUs are fully connected via XGMI.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
def fix_and_load_json(s):
# Remove trailing commas before } or ]
s = re.sub(r",(\s*[}\]])", r"\1", s)
# Insert missing commas between properties
# Match positions where a value is followed by a newline and then a quote without a comma
pattern = r'(?<=[}\]0-9truefalsenull"])\s*(\n\s*)"'
replacement = r',\1"'
s = re.sub(pattern, replacement, s)
# Now try to parse the JSON
try:
return json.loads(s)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON after fixing: {e}")
def create_finetuning_performance_report(training_time, token_counts, device_type):
tok_sec = token_counts["raw_tokens_plus_prompt_template_and_padding"] / training_time
output = f"""
| ------------------------------------------------------
| Token Counts
| - Input Tokens : {token_counts["raw_tokens"]:>5}
| - Tokens w/ Prompt : {token_counts["raw_tokens_plus_prompt_template"]:>5}
| - Total Tokens (w/ Padding) : {token_counts["raw_tokens_plus_prompt_template_and_padding"]:>5}
| -----------------------------------------------------
| Performance
| - Training Time : {training_time:.2f} s
| - Tok/sec : {tok_sec:.2f} tok/s
| -----------------------------------------------------
"""
if device_type == "cuda":
memory_used = torch.cuda.max_memory_allocated() / 1e9
output += "| Memory Usage \n"
output += f"| - Memory Used : {memory_used:.02f} GB \n"
output += "-------------------------------------------------------\n"
return output
def select_sft_generate_example(eval, data):
if eval.evaluate_example == "first":
if len(data.test_dataset.data):
instruction = data.test_dataset.data[0]["instruction"]
else:
instruction = data.train_dataset.data[0]["instruction"]
elif eval.evaluate_example == "random":
if len(data.test_dataset.data):
random_idx = random.randint(0, len(data.test_dataset.data) - 1)
instruction = data.test_dataset.data[random_idx]["instruction"]
else:
random_idx = random.randint(0, len(data.train_dataset.data) - 1)
instruction = data.train_dataset.data[random_idx]["instruction"]
elif isinstance(eval.evaluate_example, int):
index = eval.evaluate_example
if len(data.test_dataset.data) > index:
instruction = data.test_dataset.data[index]["instruction"]
elif len(data.train_dataset.data) > index:
instruction = data.train_dataset.data[index]["instruction"]
else:
raise IndexError(f"Index {index} is out of range for both test and training datasets.")
else:
raise ValueError(f"Unknown evaluation example type: {eval.evaluate_example}")
return instruction
def _RunIf(thunder: bool = False, **kwargs):
import pytest
from lightning.fabric.utilities.testing import _runif_reasons
reasons, marker_kwargs = _runif_reasons(**kwargs)
if thunder and not _THUNDER_AVAILABLE:
# if we require Thunder, but it's not available, we should skip
reasons.append("Thunder")
return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs)
def kill_process_tree(pid: int):
"""
Kill a process and all its child processes given the parent PID.
"""
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
for child in children:
child.kill()
parent.kill()
except psutil.NoSuchProcess:
pass # Process already exited
================================================
FILE: pyproject.toml
================================================
[build-system]
build-backend = "setuptools.build_meta"
requires = [
"setuptools>=68.2.2",
"wheel>=0.41.2",
]
[project]
name = "litgpt"
version = "0.5.12"
description = "Hackable implementation of state-of-the-art open-source LLMs"
readme = "README.md"
license = { file = "LICENSE" }
authors = [
{ name = "Lightning AI", email = "contact@lightning.ai" },
]
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
dependencies = [
# download models:
"huggingface-hub>=0.30,<1.4",
"jsonargparse[signatures]>=4.37,<=4.41; python_version>='3.10'", # required to work with Python >=3.10
"lightning>=2.6.1",
"psutil==7.1.3",
"safetensors>=0.4.3",
# tokenization in most models:
"tokenizers>=0.21",
"torch>=2.7",
# convert_hf_checkpoint
"tqdm>4.66",
]
optional-dependencies.compiler = [
# compilaton:
"lightning-thunder>=0.2.dev20250119; python_version>='3.10' and sys_platform=='linux'",
]
optional-dependencies.extra = [
"bitsandbytes>=0.42,<0.43; sys_platform=='darwin'",
# quantization:
"bitsandbytes>=0.45.2,<0.50; sys_platform=='linux' or sys_platform=='win32'",
# litgpt.evaluate:
"datasets>=2.18,<4",
# download:
"huggingface-hub[hf-transfer]>=0.21",
"litdata==0.2.59",
# litgpt logging:
"litlogger>=0.1.7",
# litgpt.deploy:
"litserve>0.2",
# lm-eval: pinned <0.4.9.1 due to trust_remote_code issues with datasets like logiqa.
# See: https://github.com/EleutherAI/lm-evaluation-harness/issues/3171
"lm-eval>=0.4.2,<0.4.9.1",
# litgpt.data.prepare_starcoder.py:
"pandas>=1.9",
"pyarrow>=15.0.2",
# litgpt.data:
"requests>=2.31",
# llama-based models:
"sentencepiece>=0.2",
# litgpt.pretrain:
"tensorboard>=2.14",
"torchmetrics>=1.3.1",
"transformers>=4.51.3,<4.57",
# litdata, only on non-Windows:
"uvloop>=0.2; sys_platform!='win32'",
# litgpt.data.prepare_slimpajama.py:
"zstandard>=0.22",
]
optional-dependencies.test = [
"einops>=0.7",
"protobuf>=4.23.4",
"pytest>=8.1.1",
"pytest-benchmark>=5.1",
"pytest-dependency>=0.6",
"pytest-rerunfailures>=14",
"pytest-timeout>=2.3.1",
]
urls.documentation = "https://github.com/lightning-AI/litgpt/tutorials"
urls.homepage = "https://github.com/lightning-AI/litgpt"
scripts.litgpt = "litgpt.__main__:main"
[tool.setuptools.packages.find]
include = [
"litgpt",
"litgpt.*",
]
exclude = [ ]
[tool.setuptools.package-data]
litgpt = [
"LICENSE",
"README.md",
]
[tool.ruff]
target-version = "py38"
line-length = 120
exclude = [
"build",
"dist",
"docs",
]
lint.select = [
"E",
"F", # see: https://pypi.org/project/pyflakes
"I", # implementation for isort
"UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up
"W", # see: https://pypi.org/project/pycodestyle
]
#extend-select = [
# "C4", # see: https://pypi.org/project/flake8-comprehensions
# "PT", # see: https://pypi.org/project/flake8-pytest-style
# "RET", # see: https://pypi.org/project/flake8-return
# "SIM", # see: https://pypi.org/project/flake8-simplify
#]
lint.ignore = [
"E501", # Line too long
"E731", # Do not assign a lambda expression, use a def
"E741", # todo: Ambiguous variable name
"F841", # todo: Local variable is assigned to but never used
]
# Use Google-style docstrings.
lint.pydocstyle.convention = "google"
[tool.codespell]
#skip = '*.py'
quiet-level = 3
ignore-words-list = """
tral, \
Rockerfeller
"""
[tool.pytest.ini_options]
addopts = [
"--strict-markers",
#"--doctest-modules",
"--color=yes",
"--disable-pytest-warnings",
]
================================================
FILE: tests/conftest.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import shutil
import sys
from pathlib import Path
from typing import List, Optional
import pytest
import torch
# support running without installing as a package, adding extensions to the Python path
wd = Path(__file__).parent.parent.resolve()
if wd.is_dir():
sys.path.append(str(wd))
else:
import warnings
warnings.warn(f"Could not find extensions directory at {wd}")
@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
os.chdir(tmp_path)
checkpoint_dir = tmp_path / "checkpoints" / "tmp"
checkpoint_dir.mkdir(parents=True)
(checkpoint_dir / "lit_model.pth").touch()
(checkpoint_dir / "model_config.yaml").touch()
(checkpoint_dir / "tokenizer.json").touch()
(checkpoint_dir / "tokenizer_config.json").touch()
return checkpoint_dir
class TensorLike:
def __eq__(self, other):
return isinstance(other, torch.Tensor)
@pytest.fixture()
def tensor_like():
return TensorLike()
class FloatLike:
def __eq__(self, other):
return not isinstance(other, int) and isinstance(other, float)
@pytest.fixture()
def float_like():
return FloatLike()
@pytest.fixture(autouse=True)
def restore_default_dtype():
# just in case
torch.set_default_dtype(torch.float32)
@pytest.fixture(autouse=True)
def destroy_process_group():
yield
import torch.distributed
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@pytest.fixture
def turn_off_tf32_and_set_seed(monkeypatch):
monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
torch.manual_seed(42)
yield
torch.seed()
class MockTokenizer:
"""A dummy tokenizer that encodes each character as its ASCII code."""
bos_id = 0
eos_id = 1
def encode(self, text: str, bos: Optional[bool] = None, eos: bool = False, max_length: int = -1) -> torch.Tensor:
output = []
if bos:
output.append(self.bos_id)
output.extend([ord(c) for c in text])
if eos:
output.append(self.eos_id)
output = output[:max_length] if max_length > 0 else output
return torch.tensor(output)
def decode(self, tokens: torch.Tensor) -> str:
return "".join(chr(int(t)) for t in tokens.tolist())
@pytest.fixture()
def mock_tokenizer():
return MockTokenizer()
@pytest.fixture()
def alpaca_path(tmp_path):
file = Path(__file__).parent / "data" / "_fixtures" / "alpaca.json"
shutil.copyfile(file, tmp_path / "alpaca.json")
return tmp_path / "alpaca.json"
@pytest.fixture()
def dolly_path(tmp_path):
file = Path(__file__).parent / "data" / "_fixtures" / "dolly.json"
shutil.copyfile(file, tmp_path / "dolly.json")
return tmp_path / "dolly.json"
@pytest.fixture()
def longform_path(tmp_path):
path = tmp_path / "longform"
path.mkdir()
for split in ("train", "val"):
file = Path(__file__).parent / "data" / "_fixtures" / f"longform_{split}.json"
shutil.copyfile(file, path / f"{split}.json")
return path
# https://github.com/Lightning-AI/lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/tests/tests_fabric/conftest.py#L140
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
initial_size = len(items)
conditions = []
filtered, skipped = 0, 0
options = {"standalone": "PL_RUN_STANDALONE_TESTS", "min_cuda_gpus": "RUN_ONLY_CUDA_TESTS"}
if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
# special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
# by deleting the key, we avoid filtering out the CPU tests
del options["min_cuda_gpus"]
for kwarg, env_var in options.items():
# this will compute the intersection of all tests selected per environment variable
if os.getenv(env_var, "0") == "1":
conditions.append(env_var)
for i, test in reversed(list(enumerate(items))): # loop in reverse, since we are going to pop items
already_skipped = any(marker.name == "skip" for marker in test.own_markers)
if already_skipped:
# the test was going to be skipped anyway, filter it out
items.pop(i)
skipped += 1
continue
has_runif_with_kwarg = any(
marker.name == "skipif" and marker.kwargs.get(kwarg) for marker in test.own_markers
)
if not has_runif_with_kwarg:
# the test has `@_RunIf(kwarg=True)`, filter it out
items.pop(i)
filtered += 1
if config.option.verbose >= 0 and (filtered or skipped):
writer = config.get_terminal_writer()
writer.write(
f"\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the"
f" filters {conditions}.\n{skipped} tests are marked as unconditional skips.\nIn total,"
f" {len(items)} tests will run.\n",
flush=True,
bold=True,
purple=True, # oh yeah, branded pytest messages
)
for test in items:
if "test_hf_for_nemo" in test.nodeid and "Qwen/Qwen2.5-7B-Instruct" in test.nodeid:
test.add_marker(
# Don't use `raises=TypeError` because the actual exception is
# wrapped inside `torch._dynamo.exc.BackendCompilerFailed`,
# which prevents pytest from recognizing it as a TypeError.
pytest.mark.xfail(
reason="currently not working, see https://github.com/Lightning-AI/lightning-thunder/issues/2085",
)
)
================================================
FILE: tests/convert/__init__.py
================================================
================================================
FILE: tests/convert/test_hf_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from unittest import mock
import pytest
import torch
from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama, qkv_reassemble
def test_llama2_70b_conversion():
shapes = {
"model.embed_tokens.weight": (32000, 8192),
"model.layers.0.input_layernorm.weight": (8192,),
"model.layers.0.mlp.down_proj.weight": (8192, 28672),
"model.layers.0.mlp.gate_proj.weight": (28672, 8192),
"model.layers.0.mlp.up_proj.weight": (28672, 8192),
"model.layers.0.post_attention_layernorm.weight": (8192,),
"model.layers.0.self_attn.q_proj.weight": (8192, 8192),
"model.layers.0.self_attn.k_proj.weight": (1024, 8192),
"model.layers.0.self_attn.v_proj.weight": (1024, 8192),
"model.layers.0.self_attn.o_proj.weight": (8192, 8192),
"model.layers.1.input_layernorm.weight": (8192,),
"model.layers.1.mlp.down_proj.weight": (8192, 28672),
"model.layers.1.mlp.gate_proj.weight": (28672, 8192),
"model.layers.1.mlp.up_proj.weight": (28672, 8192),
"model.layers.1.post_attention_layernorm.weight": (8192,),
"model.layers.1.self_attn.o_proj.weight": (8192, 8192),
"model.layers.2.input_layernorm.weight": (8192,),
"model.layers.2.mlp.down_proj.weight": (8192, 28672),
"model.layers.2.mlp.gate_proj.weight": (28672, 8192),
"model.layers.2.mlp.up_proj.weight": (28672, 8192),
"model.layers.2.post_attention_layernorm.weight": (8192,),
"model.layers.2.self_attn.o_proj.weight": (8192, 8192),
"model.layers.3.input_layernorm.weight": (8192,),
"model.layers.3.mlp.down_proj.weight": (8192, 28672),
"model.layers.3.mlp.gate_proj.weight": (28672, 8192),
"model.layers.3.mlp.up_proj.weight": (28672, 8192),
"model.layers.3.post_attention_layernorm.weight": (8192,),
"model.layers.3.self_attn.o_proj.weight": (8192, 8192),
"model.layers.4.input_layernorm.weight": (8192,),
"model.layers.4.mlp.down_proj.weight": (8192, 28672),
"model.layers.4.mlp.gate_proj.weight": (28672, 8192),
"model.layers.4.mlp.up_proj.weight": (28672, 8192),
"model.layers.4.post_attention_layernorm.weight": (8192,),
"model.layers.4.self_attn.o_proj.weight": (8192, 8192),
"model.layers.5.mlp.gate_proj.weight": (28672, 8192),
"model.layers.5.self_attn.o_proj.weight": (8192, 8192),
}
config = Config.from_name("Llama-2-70b-hf")
holder = {}
qkv_weights = {}
with torch.device("meta"):
weight_map = {k: torch.empty(s) for k, s in shapes.items()}
copy_weights_hf_llama(config, qkv_weights, holder, weight_map)
# NOTE: there are 5 layers, but only in the first layer we have `q`, `k` and `v`
assert len(qkv_weights) == 1
# there are no loaded qkv weights
assert all(v is None for qkv in qkv_weights.values() for v in qkv)
# the shapes are correct
holder = {k: tuple(t.shape) for k, t in holder.items()}
assert holder == {
"transformer.h.0.attn.qkv.weight": (10240, 8192),
"transformer.h.0.attn.proj.weight": (8192, 8192),
"transformer.h.0.mlp.fc_1.weight": (28672, 8192),
"transformer.h.0.mlp.fc_2.weight": (28672, 8192),
"transformer.h.0.mlp.proj.weight": (8192, 28672),
"transformer.h.0.norm_1.weight": (8192,),
"transformer.h.0.norm_2.weight": (8192,),
"transformer.h.1.attn.proj.weight": (8192, 8192),
"transformer.h.1.mlp.fc_1.weight": (28672, 8192),
"transformer.h.1.mlp.fc_2.weight": (28672, 8192),
"transformer.h.1.mlp.proj.weight": (8192, 28672),
"transformer.h.1.norm_1.weight": (8192,),
"transformer.h.1.norm_2.weight": (8192,),
"transformer.h.2.attn.proj.weight": (8192, 8192),
"transformer.h.2.mlp.fc_1.weight": (28672, 8192),
"transformer.h.2.mlp.fc_2.weight": (28672, 8192),
"transformer.h.2.mlp.proj.weight": (8192, 28672),
"transformer.h.2.norm_1.weight": (8192,),
"transformer.h.2.norm_2.weight": (8192,),
"transformer.h.3.attn.proj.weight": (8192, 8192),
"transformer.h.3.mlp.fc_1.weight": (28672, 8192),
"transformer.h.3.mlp.fc_2.weight": (28672, 8192),
"transformer.h.3.mlp.proj.weight": (8192, 28672),
"transformer.h.3.norm_1.weight": (8192,),
"transformer.h.3.norm_2.weight": (8192,),
"transformer.h.4.attn.proj.weight": (8192, 8192),
"transformer.h.4.mlp.fc_1.weight": (28672, 8192),
"transformer.h.4.mlp.fc_2.weight": (28672, 8192),
"transformer.h.4.mlp.proj.weight": (8192, 28672),
"transformer.h.4.norm_1.weight": (8192,),
"transformer.h.4.norm_2.weight": (8192,),
"transformer.h.5.attn.proj.weight": (8192, 8192),
"transformer.h.5.mlp.fc_1.weight": (28672, 8192),
"transformer.wte.weight": (32000, 8192),
"lm_head.weight": (32000, 8192), # due to weight tying lm_head is in the converted weights
}
@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2"))
def test_convert_hf_checkpoint(tmp_path, model_name):
with pytest.raises(ValueError, match="to contain .bin"):
convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)
bin_file = tmp_path / "foo.bin"
bin_file.touch()
with mock.patch("litgpt.scripts.convert_hf_checkpoint.lazy_load") as load:
# bypass if-statement for weight tying
if model_name == "Llama-2-7b-hf":
load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))}
convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)
load.assert_called_with(bin_file)
assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "model_config.yaml", "lit_model.pth"}
# ensure that the config dict can be loaded
config = Config.from_file(tmp_path / "model_config.yaml")
assert isinstance(config, Config)
def test_qkv_reassemble():
# MHA
config = Config(n_embd=4, n_head=4)
qkv_interleaved = torch.tensor(
[
[0, 1, 2, 3], # query
[16, 17, 18, 19], # key
[32, 33, 34, 35], # value
[4, 5, 6, 7], # query
[20, 21, 22, 23], # key
[36, 37, 38, 39], # value
[8, 9, 10, 11], # query
[24, 25, 26, 27], # key
[40, 41, 42, 43], # value
[12, 13, 14, 15], # query
[28, 29, 30, 31], # key
[44, 45, 46, 47], # value
]
)
qkv = qkv_reassemble(qkv_interleaved, config)
torch.testing.assert_close(
qkv,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # key
[28, 29, 30, 31], # key
[32, 33, 34, 35], # value
[36, 37, 38, 39], # value
[40, 41, 42, 43], # value
[44, 45, 46, 47], # value
]
),
)
# GQA
config = Config(n_embd=4, n_head=4, n_query_groups=2)
qkv_interleaved = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[16, 17, 18, 19], # key
[24, 25, 26, 27], # value
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[20, 21, 22, 23], # key
[28, 29, 30, 31], # value
]
)
qkv = qkv_reassemble(qkv_interleaved, config)
torch.testing.assert_close(
qkv,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # value
[28, 29, 30, 31], # value
]
),
)
# MQA
config = Config(n_embd=4, n_head=4, n_query_groups=1)
qkv_interleaved = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
)
qkv = qkv_reassemble(qkv_interleaved, config)
torch.testing.assert_close(
qkv,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
),
)
================================================
FILE: tests/convert/test_lit_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import asdict
from unittest.mock import ANY
import pytest
import torch
import yaml
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.falcon import FalconConfig, FalconForCausalLM
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.phi.configuration_phi import PhiConfig
from transformers.models.phi.modeling_phi import PhiForCausalLM
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from litgpt import GPT, Config
from litgpt.scripts.convert_lit_checkpoint import (
check_conversion_supported,
convert_lit_checkpoint,
copy_weights_falcon,
copy_weights_gemma_2,
copy_weights_gemma_3,
copy_weights_gpt_neox,
copy_weights_llama,
copy_weights_phi,
copy_weights_qwen_2_5,
qkv_reassemble,
)
from litgpt.utils import _RunIf
@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2"))
def test_convert_lit_checkpoint(tmp_path, model_name):
ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
config_path = tmp_path / "model_config.yaml"
torch.save(ours_model.state_dict(), checkpoint_path)
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
output_dir = tmp_path / "out_dir"
convert_lit_checkpoint(checkpoint_path.parent, output_dir)
assert set(os.listdir(tmp_path)) == {"lit_model.pth", "model_config.yaml", "out_dir"}
assert os.path.isfile(output_dir / "model.pth")
# check checkpoint is unwrapped
torch.save({"model": ours_model.state_dict()}, checkpoint_path)
convert_lit_checkpoint(checkpoint_path.parent, output_dir)
converted_sd = torch.load(output_dir / "model.pth")
assert "model" not in converted_sd
@torch.inference_mode()
def test_against_falcon_40b():
ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
theirs_config = FalconConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_kv_heads=ours_config.n_query_groups,
new_decoder_architecture=True,
parallel_attn=ours_config.parallel_residual,
bias=ours_config.bias,
)
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_falcon(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = FalconForCausalLM(theirs_config)
# assign must be set to True for torch.testing.assert_close to pass
theirs_model.load_state_dict(theirs_state_dict, assign=True)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
def test_against_original_gpt_neox():
ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
assert ours_config.padded_vocab_size == 512
theirs_config = GPTNeoXConfig(
hidden_act="gelu",
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
initializer_range=0.02,
intermediate_size=ours_config.intermediate_size,
layer_norm_eps=1e-05,
max_position_embeddings=ours_config.block_size,
rotary_emb_base=10000,
rotary_pct=ours_config.rotary_percentage,
vocab_size=ours_config.padded_vocab_size,
use_parallel_residual=ours_config.parallel_residual,
)
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_gpt_neox(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = GPTNeoXForCausalLM(theirs_config)
# strict=False because we don't save the rotary embeddings inv frequency
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
assert all("inv_freq" in k for k in keys.missing_keys)
# test end to end
x = torch.randint(0, ours_config.padded_vocab_size, size=(2, ours_config.block_size), dtype=torch.int64)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
"ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
)
def test_against_hf_llama2(ours_kwargs):
ours_config = Config.from_name(
padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_query_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = LlamaForCausalLM(theirs_config)
theirs_model.load_state_dict(theirs_state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1"))
def test_against_mixtral(model_name):
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
n_expert=4,
)
T = 5
theirs_config = MixtralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
num_local_experts=ours_config.n_expert,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = MixtralForCausalLM(theirs_config)
theirs_model.load_state_dict(theirs_state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
def test_against_olmo(model_name):
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
intermediate_size=86,
)
T = 5
theirs_config = OlmoConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_positional_embeddings=T,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
tie_word_embeddings=(model_name == "OLMo-1B-hf"),
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=(model_name == "OLMo-1B-hf"))
theirs_model = OlmoForCausalLM(theirs_config)
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
def test_against_original_open_llama_3b():
ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
T = 5
theirs_config = LlamaConfig(
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = LlamaForCausalLM(theirs_config)
theirs_model.load_state_dict(theirs_state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
def test_against_hf_phi(model_name):
ours_config = Config.from_name(
model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
)
T = 5
theirs_config = PhiConfig(
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
)
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = PhiForCausalLM(theirs_config)
# strict=False because we don't save the rotary embeddings inv frequency
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
assert all("inv_freq" in k for k in keys.missing_keys)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct",))
def test_against_hf_phi_3(model_name):
ours_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256)
T = 5
theirs_config = Phi3Config(
attention_bias=ours_config.bias,
head_dim=ours_config.head_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
num_key_value_heads=ours_config.n_query_groups,
pad_token_id=ours_config.padded_vocab_size - 1,
partial_rotary_factor=ours_config.rotary_percentage,
rms_norm_eps=ours_config.norm_eps,
rope_theta=ours_config.rope_base,
vocab_size=ours_config.padded_vocab_size,
)
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = Phi3ForCausalLM(theirs_config)
# strict=False because we don't save the rotary embeddings inv frequency
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
assert all("inv_freq" in k for k in keys.missing_keys)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
def test_against_original_stablelm_zephyr_3b():
T = 5
ours_config = Config.from_name("stablelm-zephyr-3b", n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = AutoConfig.from_pretrained(
"stabilityai/stablelm-zephyr-3b",
trust_remote_code=True,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_head,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True, torch_dtype=torch.float32)
theirs_model.load_state_dict(theirs_state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 5
ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config).to(device)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_model.load_state_dict(
theirs_state_dict,
strict=False,
)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-2-2b", "gemma-2-9b", "gemma-2-27b"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_2(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config).to(device)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_gemma_2(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
@pytest.mark.parametrize(
("device", "dtype"),
[
pytest.param(torch.device("cpu"), torch.float32, marks=[pytest.mark.flaky(reruns=3)]),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# todo: the reference does softmax upscaled to fp32 during attention
# additionally, the final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_3(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config).to(device)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
def test_check_conversion_supported_adapter():
lit_weights = {"some.key.name": ANY, "error.key.gating_factor": ANY}
with pytest.raises(NotImplementedError, match="Converting adapter"):
check_conversion_supported(lit_weights=lit_weights)
lit_weights = {"some.key.name": ANY, "error.key.adapter_bias": ANY}
with pytest.raises(NotImplementedError, match="Converting adapter"):
check_conversion_supported(lit_weights=lit_weights)
def test_check_conversion_supported_lora():
lit_weights = {"some.key.name": ANY, "error.key.lora": ANY}
with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"):
check_conversion_supported(lit_weights=lit_weights)
@torch.inference_mode()
@pytest.mark.parametrize(
"model_name",
(
"Qwen2.5-1.5B",
"Qwen2.5-Coder-1.5B",
"Qwen2.5-Math-1.5B",
"QwQ-32B-Preview",
"QwQ-32B",
"Qwen2.5-7B-Instruct-1M",
),
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_qwen_2_5(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Qwen2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.attn_bias,
tie_word_embeddings=True,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
ours_model = GPT(ours_config).to(device)
# tie weights
ours_model.lm_head.weight = ours_model.transformer.wte.weight
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_qwen_2_5(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
theirs_model = Qwen2ForCausalLM(theirs_config).to(device)
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
def test_qkv_reassemble():
# MHA
config = Config(n_embd=4, n_head=4)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # key
[28, 29, 30, 31], # key
[32, 33, 34, 35], # value
[36, 37, 38, 39], # value
[40, 41, 42, 43], # value
[44, 45, 46, 47], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[16, 17, 18, 19], # key
[32, 33, 34, 35], # value
[4, 5, 6, 7], # query
[20, 21, 22, 23], # key
[36, 37, 38, 39], # value
[8, 9, 10, 11], # query
[24, 25, 26, 27], # key
[40, 41, 42, 43], # value
[12, 13, 14, 15], # query
[28, 29, 30, 31], # key
[44, 45, 46, 47], # value
]
),
)
# GQA
config = Config(n_embd=4, n_head=4, n_query_groups=2)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # value
[28, 29, 30, 31], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[16, 17, 18, 19], # key
[24, 25, 26, 27], # value
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[20, 21, 22, 23], # key
[28, 29, 30, 31], # value
]
),
)
# MQA
config = Config(n_embd=4, n_head=4, n_query_groups=1)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
),
)
================================================
FILE: tests/convert/test_pretrained_checkpoint.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import torch
from litgpt.scripts.convert_pretrained_checkpoint import convert_pretrained_checkpoint
def test_convert_pretrained_checkpoint(tmp_path, fake_checkpoint_dir):
# Pretend we made a checkpoint from pretraining
pretrained_checkpoint = {
"model": {"some.module.weight": torch.rand(2, 2), "_orig_mod.some.other.module.weight": torch.rand(2, 2)},
"the_optimizer": "optimizer_state",
"other": 1,
}
torch.save(pretrained_checkpoint, fake_checkpoint_dir / "lit_model.pth")
convert_pretrained_checkpoint(checkpoint_dir=fake_checkpoint_dir, output_dir=(tmp_path / "converted"))
assert set(os.listdir(tmp_path / "converted")) == {
"lit_model.pth",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
}
converted_checkpoint = torch.load(tmp_path / "converted" / "lit_model.pth")
assert list(converted_checkpoint.keys()) == ["some.module.weight", "some.other.module.weight"]
================================================
FILE: tests/data/__init__.py
================================================
================================================
FILE: tests/data/_fixtures/alpaca.json
================================================
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
},
{
"instruction": "What are the three primary colors?",
"input": "",
"output": "The three primary colors are red, blue, and yellow."
},
{
"instruction": "Describe the structure of an atom.",
"input": "",
"output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."
},
{
"instruction": "How can we reduce air pollution?",
"input": "",
"output": "There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances."
},
{
"instruction": "Pretend you are a project manager of a construction company. Describe a time when you had to make a difficult decision.",
"input": "",
"output": "I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities."
},
{
"instruction": "Identify the odd one out.",
"input": "Twitter, Instagram, Telegram",
"output": "Telegram"
},
{
"instruction": "Explain why the following fraction is equivalent to 1/4",
"input": "4/16",
"output": "The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4."
},
{
"instruction": "Write a short story in third person narration about a protagonist who has to make an important career decision.",
"input": "",
"output": "John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \n\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\n\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him."
},
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finished his meal and left the restaurant",
"output": "He finished his meal and left the restaurant."
},
{
"instruction": "How did Julius Caesar die?",
"input": "",
"output": "Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC."
},
{
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris."
},
{
"instruction": "Generate a list of ten items a person might need for a camping trip",
"input": "",
"output": "1. Tent\n2. Sleeping bags\n3. Flashlight\n4. Matches/lighter\n5. Insect repellent\n6. Extra clothes\n7. First aid kit\n8. Camp stove\n9. Cooking utensils\n10. Biodegradable soaps/shampoos"
},
{
"instruction": "Discuss the causes of the Great Depression",
"input": "",
"output": "The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade."
},
{
"instruction": "Classify the following into animals, plants, and minerals",
"input": "Oak tree, copper or, elephant",
"output": "Oak tree: Plant\n Copper or: Mineral\n Elephant: Animal"
},
{
"instruction": "Explain the use of word embeddings in Natural Language Processing",
"input": "",
"output": "Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks."
},
{
"instruction": "Describe the function of a computer motherboard",
"input": "",
"output": "A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types."
},
{
"instruction": "Reverse engineer this code to create a new version",
"input": "def factorialize(num):\n factorial = 1\n for i in range(1, num):\n factorial *= i\n \n return factorial",
"output": "def factorialize(num):\n factorial = 1\n for i in range(num, 0, -1):\n factorial *= i\n \n return factorial"
},
{
"instruction": "Propose an ethical solution to the problem of data privacy",
"input": "",
"output": "One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data."
},
{
"instruction": "Generate three verbs that mean the same as \"to apologize\"",
"input": "",
"output": "Acknowledge, express regret, make amends."
},
{
"instruction": "What does DNA stand for?",
"input": "",
"output": "DNA stands for deoxyribonucleic acid."
},
{
"instruction": "Compare and contrast the Cuban Missile Crisis and the Vietnam War.",
"input": "",
"output": "The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force."
},
{
"instruction": "Generate a list of random words.",
"input": "",
"output": "sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic."
},
{
"instruction": "Transcribe the recording into text.",
"input": "Recording about the history of Waterloo",
"output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years."
}
]
================================================
FILE: tests/data/_fixtures/dolly.json
================================================
[
{
"instruction": "When did Virgin Australia start operating?",
"context": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
"response": "Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.",
"category": "closed_qa"
},
{ "instruction": "Which is a species of fish? Tope or Rope", "context": "", "response": "Tope", "category": "classification" },
{
"instruction": "Why can camels survive for long without water?",
"context": "",
"response": "Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.",
"category": "open_qa"
},
{
"instruction": "Alice's parents have three daughters: Amy, Jessy, and what\u2019s the name of the third daughter?",
"context": "",
"response": "The name of the third daughter is Alice",
"category": "open_qa"
},
{
"instruction": "When was Tomoaki Komorida born?",
"context": "Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.",
"response": "Tomoaki Komorida was born on July 10,1981.",
"category": "closed_qa"
},
{
"instruction": "If I have more pieces at the time of stalemate, have I won?",
"context": "Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess family.",
"response": "No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position",
"category": "information_extraction"
},
{
"instruction": "Given a reference text about Lollapalooza, where does it take place, who started it and what is it?",
"context": "Lollapalooza /\u02ccl\u0252l\u0259p\u0259\u02c8lu\u02d0z\u0259/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.\n\nLollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addiction.",
"response": "Lollapalooze is an annual musical festival held in Grant Park in Chicago, Illinois. It was started in 1991 as a farewell tour by Perry Farrell, single of the group Jane's Addiction. The festival includes an array of musical genres including alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. The festivals welcomes an estimated 400,000 people each year and sells out annually. Some notable headliners include: the Red Hot Chili Peppers, Chance the Rapper, Metallica, and Lady Gauge. Lollapalooza is one of the largest and most iconic festivals in the world and a staple of Chicago.",
"category": "closed_qa"
},
{
"instruction": "Who gave the UN the land in NY to build their HQ",
"context": "",
"response": "John D Rockerfeller",
"category": "open_qa"
},
{
"instruction": "Why mobile is bad for human",
"context": "",
"response": "We are always engaged one phone which is not good.",
"category": "brainstorming"
},
{
"instruction": "Who was John Moses Browning?",
"context": "John Moses Browning (January 23, 1855 \u2013 November 26, 1926) was an American firearm designer who developed many varieties of military and civilian firearms, cartridges, and gun mechanisms \u2013 many of which are still in use around the world. He made his first firearm at age 13 in his father's gun shop and was awarded the first of his 128 firearm patents on October 7, 1879, at the age of 24. He is regarded as one of the most successful firearms designers of the 19th and 20th centuries and pioneered the development of modern repeating, semi-automatic, and automatic firearms.\n\nBrowning influenced nearly all categories of firearms design, especially the autoloading of ammunition. He invented, or made significant improvements to, single-shot, lever-action, and pump-action rifles and shotguns. He developed the first reliable and compact autoloading pistols by inventing the telescoping bolt, then integrating the bolt and barrel shroud into what is known as the pistol slide. Browning's telescoping bolt design is now found on nearly every modern semi-automatic pistol, as well as several modern fully automatic weapons. He also developed the first gas-operated firearm, the Colt\u2013Browning Model 1895 machine gun \u2013 a system that surpassed mechanical recoil operation to become the standard for most high-power self-loading firearm designs worldwide. He also made significant contributions to automatic cannon development.\n\nBrowning's most successful designs include the M1911 pistol, the water-cooled M1917, the air-cooled M1919, and heavy M2 machine guns, the M1918 Browning Automatic Rifle, and the Browning Auto-5 \u2013 the first semi-automatic shotgun. Some of these arms are still manufactured, often with only minor changes in detail and cosmetics to those assembled by Browning or his licensees. The Browning-designed M1911 and Hi-Power are some of the most copied firearms in the world.",
"response": "John Moses Browning is one of the most well-known designer of modern firearms. He started building firearms in his father's shop at the age of 13, and was awarded his first patent when he was 24.\n\nHe designed the first reliable automatic pistol, and the first gas-operated firearm, as well inventing or improving single-shot, lever-action, and pump-action rifles and shotguns.\n\nToday, he is most well-known for the M1911 pistol, the Browning Automatic Rifle, and the Auto-5 shotgun, all of which are in still in current production in either their original design, or with minor changes. His M1911 and Hi-Power pistols designs are some of the most reproduced firearms in the world today.",
"category": "information_extraction"
},
{
"instruction": "Who is Thomas Jefferson?",
"context": "Thomas Jefferson (April 13, 1743 \u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.",
"response": "Thomas Jefferson (April 13, 1743 \u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.\n\nAmong the nation's Founding Fathers, Jefferson is considered unmatched in his intellectual depth and breadth. His passionate writings and advocacy for human rights, including freedom of thought, speech, and religion, were a leading inspiration behind the American Revolution, which ultimately gave rise to the American Revolutionary War, American independence, and the United States Constitution. Jefferson's ideas were globally influential in shaping and inspiring the Age of Enlightenment, which proved transformational in the late 17th and 18th centuries. He was a leading proponent of democracy, republicanism, and individual rights, and produced formative documents and decisions at the state, national, and international levels.\n\nDuring the American Revolution, Jefferson represented Virginia in the Second Continental Congress in Philadelphia, which adopted the Declaration of Independence on July 4, 1776. As a Virginia legislator, he drafted a state law for religious freedom. He served as the second Governor of Virginia from 1779 to 1781, during the Revolutionary War. In 1785, Jefferson was appointed the United States Minister to France, and subsequently, the nation's first secretary of state under President George Washington from 1790 to 1793. Jefferson and James Madison organized the Democratic-Republican Party to oppose the Federalist Party during the formation of the First Party System. With Madison, he anonymously wrote the Kentucky and Virginia Resolutions in 1798 and 1799, which sought to strengthen states' rights by nullifying the federal Alien and Sedition Acts.\n\nJefferson and Federalist John Adams became friends as well as political rivals, serving in the Continental Congress and drafting the Declaration of Independence together. In the 1796 presidential election between the two, Jefferson came in second, which according to electoral procedure at the time, made him vice president to Adams. Jefferson challenged Adams again in 1800 and won the presidency. After his term in office, Jefferson eventually reconciled with Adams and they shared a correspondence that lasted 14 years. He and Adams both died on the same day, July 4, 1826, which was also the 50th anniversary of Declaration of Independence.\n\nAs president, Jefferson pursued the nation's shipping and trade interests against Barbary pirates and aggressive British trade policies. Starting in 1803, he promoted a western expansionist policy with the Louisiana Purchase, which doubled the nation's claimed land area. To make room for settlement, Jefferson began the process of Indian tribal removal from the newly acquired territory. As a result of peace negotiations with France, his administration reduced military forces. He was re-elected in 1804, but his second term was beset with difficulties at home, including the trial of former vice president Aaron Burr. In 1807, American foreign trade was diminished when Jefferson implemented the Embargo Act in response to British threats to U.S. shipping. The same year, Jefferson signed the Act Prohibiting Importation of Slaves.\n\nJefferson was a plantation owner, lawyer, and politician, and mastered many disciplines including surveying, mathematics, horticulture, and mechanics. He was also an architect in the Palladian tradition. Jefferson's keen interest in religion and philosophy led to his appointment as president of the American Philosophical Society. He largely shunned organized religion but was influenced by Christianity, Epicureanism, and deism. Jefferson rejected fundamental Christianity, denying Christ's divinity. A philologist, Jefferson knew several languages. He was a prolific letter writer and corresponded with many prominent people, including Edward Carrington, John Taylor of Caroline, and James Madison. In 1785, Jefferson authored Notes on the State of Virginia, considered perhaps the most important American book published before 1800. Jefferson championed the ideals, values, and teachings of the Enlightenment.\n\nSince the 1790s, Jefferson was rumored to have had children by his sister-in-law and slave Sally Hemings, leading to what is known as the Jefferson-Hemings controversy. A 1998 DNA test concluded that one of Sally Hemings's children, Eston Hemings, was of the Jefferson male line. According to scholarly consensus, based on documentary and statistical evaluation, as well as oral history, Jefferson probably fathered at least six children with Hemings, including four that survived to adulthood.\n\nAfter retiring from public office, Jefferson founded the University of Virginia. Presidential scholars and historians generally praise Jefferson's public achievements, including his advocacy of religious freedom and tolerance in Virginia, his peaceful acquisition of the Louisiana Territory from France without war or controversy, and his ambitious and successful Lewis and Clark Expedition. Some modern historians are critical of Jefferson's personal involvement with slavery. Jefferson is consistently ranked among the top ten presidents of American history.",
"category": "information_extraction"
},
{
"instruction": "Who was Kyle Van Zyl playing against when he scored 36 of hisa teams 61 points?",
"context": "Van Zyl joined the Eastern Province Kings Academy, where he played for the Eastern Province U19 side in the 2010 Under-19 Provincial Championship. He was a key player for the Eastern Province U21 side in the 2012 Under-21 Provincial Championship, scoring 71 points in eight appearances. Van Zyl was under the Top SARU Performers, scoring the most tries at 6 in the 2012 Provincial Under 21 in the Rugby Junior Provincials.\n\nThis included a record and a remarkable personal haul in their opening match, when he scored 36 of his team's points in a 61\u20133 victory over Boland U21, consisting of four tries and eight conversions and was awarded Man of the Match.",
"response": "Kyle Van Zyl was playing against Boland U21 when he scored 36 points, leading his team to victory in a 61-3 win.",
"category": "closed_qa"
}
]
================================================
FILE: tests/data/_fixtures/longform_train.json
================================================
[
{
"input": "What are the positions held by Beto O'Rourke, Lupe Valdez, and Veronica Escobar on decriminalizing unauthorized border crossings in the US?",
"output": "The liberal mainstream media squealed with delight last week when they announced Ted Cruz was only 4 points ahead of Democrat Beto O\u2019Rourke.\nThe the bottom dropped out on Beto\u2019s campaign.\nFirst O\u2019Rourke told a liberal Texas crowd that there was nothing more American than kneeling for the National Anthem.\nBeto O\u2019Rourke is now pushing to legalize unauthorized border crossers into America.\nA trio of Democrats running for top offices in Texas want to decriminalize unauthorized border crossings.\nU.S. Rep. Beto O\u2019Rourke, who\u2019s challenging Republican Sen. Ted Cruz for his Senate seat, gubernatorial candidate Lupe Valdez, and congressional candidate Veronica Escobar told HuffPost they are in favor of such a move, which the report says would be simpler to implement than abolishing Immigration and Customs Enforcement, as some Democrats have demanded.\nDecriminalizing unauthorized border crossings would go against the Trump administration\u2019s \u201czero tolerance\u201d policy that led to family separations at the border. The Department of Homeland Security would still be able to charge unauthorized immigrants with civil violations, the report notes, but it would save the federal government billions in incarceration costs.",
"source": "C4",
"subset": "search-engine"
},
{
"input": "For example, for the original quote we shall use, 'One swallow does not a summer make'. If someone were to say 'One good quote does not a philosopher make' (No dig at Aristotle intended...), or 'A fancy suit does not a businessman make'. Essentially, a clear allusion to a quote, but not the quote itself.",
"output": "The word I've heard used for this on the linguistics blog Language Log is snowclone (it's derived from phrases of the format 'If Eskimos have N words for snow, X surely have Y words for Z.') Here's a Language Log post that describes the moment the word was coined and that links to some examples: Snowclones: Lexicographical Dating to the Second And here's how it's defined by Collins Dictionaries online: a verbal formula that is adapted for reuse by changing only a few words so that the allusion to the original phrase remains clear You might also be interested in looking at the answers to the following questions: This book will change your wife trope? Is there a general name for this: 'The more you X the more you'll Y' A term describing the replacement of a specific word in a saying with one that ryhmes",
"source": "Natural Instructions",
"subset": "task225 english language answer generation"
},
{
"input": "Describe the U.S.-Mexico border. Respond in detail.",
"output": "President Donald Trump\u2019s border wall with Mexico will have an initial price tag of between $12 billion and $15 billion, according to Congressional Republican leaders, but the method of paying for Trump\u2019s promised project is still not settled.\nAt the Congressional Republican Retreat in Philadelphia on Thursday, House Speaker Paul Ryan said the wall would cost up to $15 billion, more than the $8 billion Trump has estimated the wall will cost.\nA U.S. Border Patrol agent removes a ladder used by undocumented immigrants to climb a border fence on Oct. 18, 2016 near McAllen, Texas.\nU.S. Customs and Border Protection agents fly near the U.S.-Mexico border while on helicopter patrol on Oct. 18, 2016 near McAllen, Texas.\nA U.S. Border Patrol agent detains a group of undocumented immigrants on Oct. 18, 2016 near McAllen, Texas.\nPeople stand in line to cross legally into the United States from Mexico on Sept. 24, 2016 in Tijuana, Mexico.\nThe Rio Grande forms a stretch of the U.S.-Mexico border on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas. Big Bend is a rugged, vast and remote region along the U.S.-Mexico border and includes the Big Bend National Park.\nU.S. Border Patrol agents with a K-9 unit detain undocumented immigrants after they illegally crossed the U.S.-Mexico border on Oct. 18, 2016, in McAllen, Texas.\nThis photo made with a smart phone through night vision goggles shows the Rio Grande flowing along the U.S.-Mexico border, as seen from a U.S. Customs and Border Protection helicopter during a patrol over the U.S.-Mexico border on Oct. 18, 2016 in McAllen, Texas.\nThe moon rises over the swirling current of the Rio Grande on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas.\nA bullet-proof shield stands to aid U.S. Border Patrol agents on the U.S.-Mexico border on Oct. 3, 2016 in El Paso, Texas.\nA child plays in the Pacific surf near the U.S.-Mexico border fence on Sept. 25, 2016 in Tijuana, Mexico. The nearby Friendship Park is one of the few places on the 2,000-mile border where separated families are allowed to meet.\nDunes stretch into the distance near the U.S.-Mexico border on Sept. 27, 2016 in the Imperial Sand Dunes recreation area, California.\nMexican farm workers hoe a cabbage field on Sept. 27, 2016, in Holtville, Calif. Thousands of Mexican seasonal workers legally cross over daily from Mexicali, Mexico to work the fields of Imperial Valley, Calif., which is some of the most productive farmland in the United States.\nA man looks through the U.S.-Mexico border fence into the United States on Septt. 25, 2016 in Tijuana, Mexico.\nA Studebaker sits outside a cotton field near the U.S.-Mexico border on Oct. 14, 2016 near Fort Hancock, Texas.\nBodies of suspected undocumented immigrants lie in the morgue at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz. Hundreds of migrants die every year in the desert while crossing illegally from Mexico into the United States.\nPersonal effects found on the body of a suspected undocumented immigrant are stored at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz.\nDinner tables await immigrants, many of them deportees from the United States, at the Casa del Migrante Sept. 23, 2016 in Tijuana, Mexico. The shelter, run by Catholic priests, is part of the Coalicion Pro Defensa, which aids immigrants and asylum applicants seeking entry into the United States. The coalition of immigrant advocates has helped thousands of migrants on their journey north and many more after being being deported from the United States.\nMexicans enjoy a late afternoon near the U.S.-Mexico border fence, which ends in the Pacific Ocean, on Sept. 25, 2016 in Tijuana, Mexico. Friendship Park, located on the border between the two countries is one of the few places on the 2,000-mile border where separated families are allowed to meet.\nTrump signed two executive orders on Wednesday directing the construction of the border wall and increasing the number of immigration enforcement officers carrying out deportation, as well as border patrol forces.",
"source": "C4",
"subset": "instruction"
},
{
"input": "Must a Boss use their Level Up ability?\n\nMost Level Up abilities clearly benefit the Boss, but Robobo's is:\n\nEach opponent mush choose and destroy one Room in his or her Dungeon.\n\nThe opponents get to destroy a room (often quite beneficial), and Robobo doesn't!\nIs there any way - short of only building a four room dungeon - that Robobo can avoid using this ability? What play strategies can Robobo use to maximize his Level Up ability?",
"output": "According to the rules, your boss's \"Level Up\" ability is triggered when your dungeon is 5 rooms in size. \"Triggering\" implies necessity. So yes, the only way to avoid levelling up is by not building a fifth room.\nSecondly, and more to your concern... according to the Advanced Rules (see the entry for Robobo in the glossary using the link above), using Robobo's level up does not trigger the \"destroy this Room\" ability for the destroyed room.",
"source": "StackExchange",
"subset": "boardgames"
},
{
"input": "What is Podd\u0119bice?",
"output": "Podd\u0119bice () is a village in the administrative district of Gmina W\u0142oc\u0142awek, within W\u0142oc\u0142awek County, Kuyavian-Pomeranian Voivodeship, in north-central Poland. It lies approximately south of W\u0142oc\u0142awek and south-east of Toru\u0144.",
"source": "Wikipedia",
"subset": "chatbot"
},
{
"input": "Describe the history of MV \"Alexia\" Respond in 4 sentences.",
"output": "MV \"Alexia\" was one of nine Royal Dutch/Shell oil tankers converted to become a Merchant Aircraft Carrier (MAC ship). The group is collectively known as the \"Rapana\" class.\nMV \"Alexia\" was built at Bremer Vulkan and completed in April, 1935 as an oil tanker for the Anglo Saxon Royal Dutch/Shell line. She sustained severe damage in two separate U-boat attacks in 1940 and 1942.",
"source": "Wikipedia",
"subset": "instruction"
},
{
"input": "What is the definition of absurdity?",
"output": "1. the state or quality of being absurd.\nAttending the Gerald R. Ford Symposium on Humor and the Presidency is sort of like attending the Ayatollah Khomeini Symposium on the sexual revolution \u2014Pat Paulsen, at September 19, 1986 symposium in Grand Rapids, Michigan.\n2. absurdity - a ludicrous folly; \"the crowd laughed at the absurdity of the clown's behavior\"\nnoun ridiculousness, nonsense, folly, stupidity, foolishness, silliness, idiocy, irrationality, incongruity, meaninglessness, daftness (informal), senselessness, illogicality, ludicrousness, unreasonableness, preposterousness, farcicality, craziness (informal), b\u00eatise (rare), farcicalness, illogicalness I get angry at the absurdity of a situation.\nfolly, foolery, foolishness, idiocy, imbecility, insanity, lunacy, madness, nonsense, preposterousness, senselessness, silliness, tomfoolery, zaniness.\nBut in the Epic poem the absurdity passes unnoticed.\nFor if absurdity be the subject of laughter, doubt you not but great boldness is seldom without some absurdity.\nI am temperate to the verge of absurdity,\" replied the Tramp.\nPoets, of course, may be satisfactorily read in volumes of, selections; but to me, at least, a book of brief extracts from twenty or a hundred prose authors is an absurdity.\nThen,\" suggested the idea, with a blush for its own absurdity, \"why not go on pilgrimage and seek her?\nJudges and starters have been conveniently blind to this absurdity, but the public demonstration off St.\nAnd now that the providential occurrence was apparently close at hand, it would have been sheer absurdity to think that the supply would be short of the need: as absurd as a faith that believed in half a miracle for want of strength to believe in a whole one.\nutterances, the absurdity being attested by his motley costume.\nOr what greater absurdity can there be than putting before us an old man as a swashbuckler, a young man as a poltroon, a lackey using fine language, a page giving sage advice, a king plying as a porter, a princess who is a kitchen-maid?\nNow to say that the honour I here mean, and which was, I thought, all the honour I could be supposed to mean, will uphold, much less dictate an untruth, is to assert an absurdity too shocking to be conceived.\nHence the absurdity of the interview; the gulf between them was economic as well as spiritual.",
"source": "C4",
"subset": "chatbot"
},
{
"input": "Can felons run for federal office in Minnesota?",
"output": "Minnesota law doesn't block felons from running for federal office.\nEven if enough voters choose Leonard J. Richards as the DFL candidate for U.S. Senate, there\u2019s no way he will ever get to take the oath of office and begin a six-year term.\nRichards is already serving a lifetime term. In Stillwater prison. For murder. Make that two murders.\nYes, it is legal in Minnesota for felons to run for office, so long as it is a federal seat. Nobody knows that better than Richards, who is trying to wrest the party nod away from incumbent Amy Klobuchar.\nNow 75 years old and sporting Department of Corrections ID No. 149837, Richards has run for federal office several times \u2014 without a victory \u2014 since his imprisonment, most recently when he sought the seat that U.S. Rep. Tom Emmer now holds.\nIn 1992, Richards ran in the DFL primary for the Eighth Congressional District seat and received more than 14,500 votes. He ran for the U.S. Senate in the DFL primary in 1994, winning more than 4,000 votes.\nMinnesota law does not permit inmates to run for a state-level office.\nThe official ballot for Minnesota\u2019s primary next month lists double murderer Leonard Richards among those seeking the DFL nomination for U.S. Senate.\nRichards was convicted of murder twice in Hennepin County for the 1982 slaying of his half-sister, May Wilson, and the 1987 shooting death of his attorney, Robert Stratton. His life sentence imposed nearly 30 years ago offers no parole.\nRichards did not respond to a message Thursday seeking an interview.",
"source": "C4",
"subset": "chatbot"
},
{
"input": "What is the purpose of consciousness?",
"output": "Cardiff University and University College London provide funding as founding partners of The Conversation UK.\nMost experts think that consciousness can be divided into two parts: the experience of consciousness (or personal awareness), and the contents of consciousness, which include things such as thoughts, beliefs, sensations, perceptions, intentions, memories and emotions.\nIt\u2019s easy to assume that these contents of consciousness are somehow chosen, caused or controlled by our personal awareness \u2013 after all, thoughts don\u2019t exist until until we think them. But in a new research paper in Frontiers of Psychology, we argue that this is a mistake.\nWe suggest that our personal awareness does not create, cause or choose our beliefs, feelings or perceptions. Instead, the contents of consciousness are generated \u201cbehind the scenes\u201d by fast, efficient, non-conscious systems in our brains. All this happens without any interference from our personal awareness, which sits passively in the passenger seat while these processes occur.\nPut simply, we don\u2019t consciously choose our thoughts or our feelings \u2013 we become aware of them.\nIf this sounds strange, consider how effortlessly we regain consciousness each morning after losing it the night before; how thoughts and emotions \u2013 welcome or otherwise \u2013 arrive already formed in our minds; how the colours and shapes we see are constructed into meaningful objects or memorable faces without any effort or input from our conscious mind.\nConsider that all the neuropsychological processes responsible for moving your body or using words to form sentences take place without involving your personal awareness. We believe that the processes responsible for generating the contents of consciousness do the same.\nOur thinking has been influenced by research into neuropsychological and neuropsychiatric disorders, as well as more recent cognitive neuroscience studies using hypnosis. The studies using hypnosis show that a person\u2019s mood, thoughts and perceptions can be profoundly altered by suggestion.\nIn such studies, participants go through a hypnosis induction procedure, to help them to enter a mentally focused and absorbed state. Then, suggestions are made to change their perceptions and experiences.\nFor example, in one study, researchers recorded the brain activity of participants when they raised their arm intentionally, when it was lifted by a pulley, and when it moved in response to a hypnotic suggestion that it was being lifted by a pulley.\nSimilar areas of the brain were active during the involuntary and the suggested \u201calien\u201d movement, while brain activity for the intentional action was different. So, hypnotic suggestion can be seen as a means of communicating an idea or belief that, when accepted, has the power to alter a person\u2019s perceptions or behaviour.\nAll this may leave one wondering where our thoughts, emotions and perceptions actually come from. We argue that the contents of consciousness are a subset of the experiences, emotions, thoughts and beliefs that are generated by non-conscious processes within our brains.\nThis subset takes the form of a personal narrative, which is constantly being updated. The personal narrative exists in parallel with our personal awareness, but the latter has no influence over the former.\nThe personal narrative is important because it provides information to be stored in your autobiographical memory (the story you tell yourself, about yourself), and gives human beings a way of communicating the things we have perceived and experienced to others.\nThis, in turn, allows us to generate survival strategies; for example, by learning to predict other people\u2019s behaviour. Interpersonal skills like this underpin the development of social and cultural structures, which have promoted the survival of human kind for millennia.\nSo, we argue that it is the ability to communicate the contents of one\u2019s personal narrative \u2013\u2013 and not personal awareness \u2013 that gives humans their unique evolutionary advantage.\nIf the experience of consciousness does not confer any particular advantage, it\u2019s not clear what its purpose is. But as a passive accompaniment to non-conscious processes, we don\u2019t think that the phenomenon of personal awareness has a purpose, in much the same way that rainbows do not. Rainbows simply result from the reflection, refraction and dispersion of sunlight through water droplets \u2013 none of which serves any particular purpose.\nOur conclusions also raise questions about the notions of free will and personal responsibility. If our personal awareness does not control the contents of the personal narrative which reflects our thoughts, feelings, emotions, actions and decisions, then perhaps we should not be held responsible for them.\nIn response to this, we argue that free will and personal responsibility are notions that have been constructed by society. As such, they are built into the way we see and understand ourselves as individuals, and as a species. Because of this, they are represented within the non-conscious processes that create our personal narratives, and in the way we communicate those narratives to others.\nJust because consciousness has been placed in the passenger seat, does not mean we need to dispense with important everyday notions such as free will and personal responsibility. In fact, they are embedded in the workings of our non-conscious brain systems. They have a powerful purpose in society and have a deep impact on the way we understand ourselves.",
"source": "C4",
"subset": "chatbot"
},
{
"input": "What is Nemaha, IA? Respond briefly.",
"output": "Nemaha is a city in Sac County, Iowa, United States. The population was 85 at the 2010 census.\nAccording to the United States Census Bureau, the city has a total area of , all of it land.",
"source": "Wikipedia",
"subset": "chatbot"
},
{
"input": "In the book I am reading (Patrick Fairbairn, The Interpretation of Prophecy, 2nd ed., 1865), the abbreviation 'sq.' is used sometimes in referencing both Scripture and other books (as 'Typology of Scripture, vol. i. p. 100 sq.' or 'Ezek xxvi. 7, sq.'). What does 'sq.' mean in this context? (Instances of this abbreviation's usage may be found at Google Books here.).",
"output": "Sq. is short for the Latin word sequiturque, which means the following. So p. 100 sq. means page 100 and those following pages. Here's the source.",
"source": "Natural Instructions",
"subset": "task225 english language answer generation"
},
{
"input": "\"Describe Brugnens.\" Respond in 11 words.",
"output": " Brugnens is a commune in the Gers department in southwestern France.",
"source": "Wikipedia",
"subset": "instruction"
},
{
"input": "What is the DeLand Municipal Airport? Respond in 1 sentences.",
"output": "DeLand Municipal Airport , also known as Sidney H. Taylor Field, is a city-owned public-use airport located three nautical miles (5\u00a0km) northeast of the central business district of DeLand, a city in Volusia County, Florida, United States.",
"source": "Wikipedia",
"subset": "chatbot"
},
{
"input": "Respond in 203 words. Describe the plot of the 1976 Filipino drama film \"Minsa'y Isang Gamu-gamo\".",
"output": "Minsa'y Isang Gamu-gamo (\"\") is a 1976 Filipino drama film directed by Lupita Aquino-Kashiwahara and written by Marina Feleo-Gonzales. It concerns a Filipino nurse, Cora de la Cruz, who dreams of moving to America. When her brother is killed, her ideas change. The film criticises American military presence in the Philippines. It was an entry to the 1976 Metro Manila Film Festival and won five awards at the 25th FAMAS Awards, including those for best picture, director and screenplay. The film was restored by the ABS-CBN Film Restoration through Kantana Post Production in Thailand and the restored version was released on October 12, 2018 as part of Cinema One Originals film festival.\nThe de la Cruises and the Santos are two lower-middle-class families who live in Pampanga. Cora de la Cruz is a nurse who dreams of living in the United States. Her papers are ready and she organises employment at an American hospital. She hopes to get a green card allowing her to stay, achieve immigrant status and then bring her family to America for a better life. Cora's mother and younger brother encourage her but her father, Ingkong, disagrees. He believes that moving to America is a betrayal of their country.",
"source": "Wikipedia",
"subset": "instruction"
},
{
"input": "Hey, when does Marvel's Runaways Season 2 start?",
"output": "The second season picks up where the first left off.\nHulu has announced that the second season of Marvel's Runaways will premiere December 21st, Deadline reports. And there won't be any weekly installments this time around as there were with the first season because all 13 episodes are getting dropped at once. Hulu renewed the series for a second season just as the first one wrapped up.\nSeason 2 will kick off right where the first ended -- with the kids having escaped their evil parents and setting out their own. Now they'll have to figure out how to live with each other, how to take care of themselves and how to take down their parents' nefarious Pride squad. But they might also have to sniff out a mole since someone's sending the Pride leader mysterious messages. Meanwhile, the Pride is bound and determined to get them back.",
"source": "C4",
"subset": "chatbot"
},
{
"input": "I need to serve at a dinner party, but I'm not sure where to start. Can you give me 10 actionable steps?",
"output": "Which way should you pass the plates? How should you clear the table? Serving at a dinner party isn't exactly as easy as pie. Here are some basic guidelines to help you out in your next dinner party.\n1. Steamed or roasted vegetables, rice, salad, potatoes, and condiments are some of the things you can usually leave in the middle of the table for guests to help themselves to. This isn't strictly necessary, though. If you prefer to keep things simple, you don't have to leave anything more than salt and pepper on the table.\n2. This is known as \"plating-up.\" Don't expect guests to help themselves to food that requires assembly or is difficult to manage. The only way they'll get to see your creativity is if you put it out on the plate for them before you even get to the dinner table.\nKeep the rim of the plate clear of sauces, spills, anything; it frames the food.\n3. You can do this if you want to keep with tradition and the occasion is quite formal. Otherwise, choose one end of the table and move around it in a clockwise fashion, regardless of the genders of your guests.\n\nPass all dishes from the left. Guests and servers should pass dishes from left. The logic behind this is that most people are right-handed and this allows them them to serve themselves from the dish while it is being held by the passer. Nowadays it is less likely that the passer will continue to hold the dish, but will expect you to take it, so if you are left-handed, it won't be a problem. Place the dish down on your side plate to serve from it.\n\nAs the cook, or host, always serve yourself last. This is polite and also sensible, since you'll probably be busy anyway with host's duties.\n4. They'll get fidgety, anxious and gossipy about what you're doing.\n5. On the other hand, do not ever go into details about how the flesh portion of the meal was hunted/killed. This is bad taste and makes some guests very queasy. Leave it for discussion around the fireplace with a like-minded friend after dinner.\n6. The host or hired help should clear no more than two plates at a time to avoid bumping guests and interfering with their eating. There is nothing more annoying than the server's elbows in your face when you're just about to take the next bite.\n7. Preferably the noises should not reach the guests but this is unrealistic for most homes. Just do it as quietly as possible and try not to clank, crack, break or drop the dishes. The last thing you need on top of anything else is a dropped plate to clean up.\n8. This means all the dishes on the table, the condiments and the side plates. If you haven't already set out the dessert spoons, this is the time to do so.\n9. Chocolates will have their own method of getting around the table; it has been suspected they have legs...\n10. Don't take advantage of the situation because that guest is there to enjoy himself too but don't hesitate to ask for a quick hand with a simple task that won't risk spills on their clothes.\n",
"source": "WikiHow",
"subset": "main"
},
{
"input": "What can you tell me about Brushcreek Township? Respond in 2 sentences.",
"output": "Brushcreek Township is one of the seventeen townships of Highland County, Ohio, United States. As of the 2010 census the population was 1,381, of whom 1,248 lived in the unincorporated portion of the township.",
"source": "Wikipedia",
"subset": "chatbot"
}
]
================================================
FILE: tests/data/_fixtures/longform_val.json
================================================
[
{
"input": "The Big Mistake\n\nThis day was full of joy and happiness, but something went wrong after when she told him that she is pregnant. \nShe was so happy about this news and she thought that he will be happy as well, but... At the moment when she told him that she is pregnant, he just frozen. It was too much for him.\nHe thought it's easier to walk away and avoid this conversation with her. He tried to understand what's going on with him. It was weird how many questions and doubts came up on his head. One part of him was really happy, because to be a father is his one of the biggest dream, but on the other hand he couldn't handle this news. \nThis woman, who stands next to him, is his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loves him as much as he loves her, but everything was just too wrong on this moment. Her voice, her smile and even her body was to wrong on this moment.\nShe couldn't understand what\u2019s going on, because it was just three weeks ago when they had a conversation about kids and family and he said that he would like to have kids with her. And now, he just avoids her and he is not happy about the news. \nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they love each other and they care about each other, but it was too much for him. He stood up and said to her girlfriend that he needs a break and he have to be alone. Michael closed the door and knew at that moment he had made a mistake.\n\nFix all grammar errors and style issues and rewrite this document",
"output": "The Big Mistake\n\nThis day was full of joy and happiness, but something went wrong when she told him that she was pregnant. \nShe was so happy about this news and she thought that he would be happy as well, but... At the moment when she told him that she was pregnant, he just frozen. It was too much for him.\nHe thought it was easier to walk away and avoid this conversation with her. He tried to understand what was going on with him. It was weird how many questions and doubts came into his head. One part of him was really happy, because to be a father was one of his biggest dreams, but on the other hand, he couldn't handle this news. \nThis woman, who was standing next to him, was his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loved him as much as he loved her, but everything was just all wrong at this moment. Her voice, her smile and even her body was wrong on this moment.\nShe couldn't understand what was going on, because it was just three weeks ago that they had a conversation about kids and family and he said that he would like to have kids with her. And now, he was just avoiding her and he was not happy about the news. \nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they loved each other and they cared about each other, but it was too much for him. He stood up and said to his girlfriend that he needed a break and he had to be alone. Michael closed the door and knew at that moment he had made a mistake.\n",
"source": "BEA-GEC",
"subset": "main"
},
{
"input": "Although the fact that personal vehicle like car now is popular, in the future, the stance of public transport is still very crucial. \n\nNowaday, people use fossil fuel as the most important source of energy for travelling, manufacturing, this source of energy doesn't always maintain at the constant level, but someday it will be run out. People must find ways to use other energy sources like nuclear electricity, wind power, public transportation such as electric metro use this kind of energy. As a result, governments will invest more in researching the usage of these new types of energy, travelling using public transport will become much cheaper than personal car. The range of public services will be expanded to remote areas, it become much more convenient. Commuting, travelling using public services will be the first prior choice.\n\nIn addition, the population is increasing rapidly, in the near future, almost cities in the world will become overcrowded, It will lead to many other problems to be solved, for example, traffic jams, air pollution. In some perespectives, these problems are led by the over-using of personal vehicle. People will realise that using their own car is just to waste their time for traffic jams or to gradually killing themselves for breathing polluted air. So the best way for them to travel is public transpotations. \nEach person should practice saving when using any source of eneny to protect his own live. \nIn conclusion, investments on developing public transport will be increased considerably, public services of travelling have a bright potential and it's stance in the future can't be replaceble\n\nTweak this document to make it read smoother and fix any mistakes you see.",
"output": "Despite the fact that personal vehicles like the car are now popular, in the future, the existence of public transport will still absolutelyery crucial. \n\nNowadays, people use fossil fuels as the most important source of energy for travelling and manufacturing. This source of energy doesn't always remain at a constant level, and someday it will run out. People must find ways to use other energy sources like nuclear power and wind power. Public transportation such as the electric metro use this kind of energy. As a result, governments will invest more in researching the use of these new types of energy, and travelling using public transport will become much cheaper than privatepersonal car. The range of public services will be expanded to remote areas, and it will become much more convenient. Commuting and travelling using public services will be the first priority.\n\nIn addition, the population is increasing rapidly. In the near future, most cities in the world will become overcrowded. This will lead to many other problems that will need to be solved. For example, traffic jams and air pollution. In some ways, these problems are caused by the overuse of personal vehicles. People will realise that using their own car is just wasting their time in traffic jams or gradually killing themselves by breathing polluted air. So the best way for them to travel is public transport. \nEach person should practice saving energy when using any source of energy to protect his own life. \nIn conclusion, investments in developing public transport will be increased considerably. Public transport services have a bright future and their existence in the future can't be replaced.\n",
"source": "BEA-GEC",
"subset": "main"
},
{
"input": "What can you tell me about Nang Ka Teih? Respond in 58 words.",
"output": "Founded by the people of Chin (Zo), Nang Ka Teih is a town ward within Sagain Division or next to Khampat town in the Myanmar state of Chin and Sagain Division. It has a very high concentration of the religion. Nang Ka Teih is one of the best location in Chin State and listed among the peace area.",
"source": "Wikipedia",
"subset": "chatbot"
},
{
"input": "I'm having a serious argument with a friend on the status of the word 'suck' when I used it about him by saying 'You suck!' because he missed a train. We are both non-native English speakers. He claimed that I used it as a foul or slang term. I vehemently disagreed with him about it. I told him that the word 'suck' can be used to describe something inefficient or not good enough as well like it is used in the example 'Samsung mobile sucks'. Am I right to describe the use of the word 'suck' in such context?.",
"output": "Yes, you can use the word in the way that you have described, but it's considered more harsh than polite, and it has somewhat vulgar overtones. How it's regarded or received might be generational. I typed is suck vulgar? on Google, and found mixed responses. Feel free to do the same if you want diverse opinions on the matter. I thought this excerpt from a blog post, though, was worth pasting into an answer here: Some may not believe this, but suck as in 'Man, this class sucks' was also in the raw obscenity category when I was a teenager. It was used plenty in the school hallways but not in front of your teacher and never in front of your mother. I remember some agitation by certain culturally-advanced youngsters who tried to railroad their elders into accepting sucks as a safe and harmless substitute for stinks. The elders weren't having any of it, last I checked, but the liberalizing linguists seem to have carried the day. I have always assumedrightly or wrongly, I do not know that the word was originally intended to carry sexual overtones, which was the reason for its suppression. Today, the sexual overtones are either forgotten or are now acceptable in mixed company. I'm not sure which explanation disturbs me more. I think you and your friend are unlikely to come up with an agreed-upon viewpoint, because you're both right in a way. Feel free to use it on message boards and the like when you want to express a negative opinion, but realize you'll risk sounding a bit uncouth to some when you do. Then again, maybe I'm just showing my age here. As a footnote, you might want to check out our sister site, English Language Learners.",
"source": "Natural Instructions",
"subset": "task225 english language answer generation"
},
{
"input": "How is BeeHighve Inc. in Corner Brook, NL infusing cannabis with honey products and bringing them to markets, both local and global?",
"output": "BeeHighve CEO Rita Hall intends to bring Newfoundland honey and honey-based products to the market, some infused with cannabis, while others will be \"buzz free.\"\nA selection of some of the products available from Corner Brook, N.L.-based BeeHighve Inc.\nThere's a lot of buzz around a joint venture between two Newfoundland companies that want to bring cannabis-infused honey and honey products to local and global markets.\nBeeHighve Inc., based in Corner Brook, will be cultivating the cannabis crop, while G and M Family Farms, near Placentia, will supply the honey.\n\"It ranges from pure honey to sauces as well as chocolate and health bars, and everything is organic,\" said BeeHighve CEO Rita Hall. \"And everything is very healthy for you. We don't use sugar in our products.\"\nThe partnership is the brainchild of Hall, a trailblazer who is on track to become the first Indigenous woman to gain licensed producer [of marijuana] status in Canada, as well as one of the country's few Indigenous female CEOs.\nHall intends for Newfoundland honey to be the backbone of her operation, and utilizes it in all the products, including the flagship Nuts About Honey bars.\nAlthough no official date has been set for the legalization of cannabis-infused edibles \u2014 with recreational marijuana use legal as of Oct. 17, 2018 \u2014 BeeHighve plans to produce the same line of products, without the cannabis elements before and after the legalization of edibles.\n\"I don't think it's going to have a negative impact on the business at all. The honey is really generating a lot of interest,\" Hall said.\nPart of the interest is because of Newfoundland's uniquely thriving honey-bee population, who enjoy the benefits of a closed ecosystem comparatively free from mites and diseases associated with colony collapse.\n\"They love the idea of mite-free, antibiotic-free honey. So I have no doubt that the production and sale of honey and the consumables without cannabis infusions will go very well.\"\nEventually BeeHighve intends to get into the beekeeping business as well, allowing them to produce larger volumes of product in less time.\nAside from the plans to export the cannabis-infused products \u2014 where legal \u2014 as well as the \"buzz-free\" ones globally, BeeHighve is looking to expand its production to another province.\nPartnering with the Madawaska Maliseet First Nation reserve in New Brunswick, crops will be cultivated on the reserve to later be infused with Newfoundland honey. Hall believes the partnership will be a fruitful one, and has great respect for the Madawaska Maliseet, whose senior leadership is entirely made of women.\n\"It shows the strength of women in any marketplace. Women are really underrepresented in the cannabis industry right now, so it's, I'll say, a feather in our caps, no pun intended, to be a part of this industry as well.\"\nHall has just as much faith in this venture as the one in Newfoundland.\n\"We're very strong women and I think very successful, and we'll succeed at this as well.\"",
"source": "C4",
"subset": "search-engine"
},
{
"input": "Respond briefly. What is the history of the Eagle Ranger Station?",
"output": "The Eagle Ranger Station, also known as the Eagle Guard Station and presently known as the Sol Duc Ranger Station, is a complex of three buildings built in the 1930s in what would become Olympic National Park. The primary structures were built by the U.S. Forest Service in what was at the time the Olympic National Forest., While the main residence was built by the USFS, the generating plant and landscaping were built by the National Park Service using labor provided by the Civilian Conservation Corps.",
"source": "Wikipedia",
"subset": "chatbot"
},
{
"input": "How to deal with non negotiable change\n\nOne of the issues my organisation has is where change requests are requested during a sprint cycle in a particular high volume worksteam. The development team are often forced to accept these change requests, where any resistance results in an escalation, which results with somebody high up in the business just telling the dev team to do it. Thus, team is constantly over-committing every sprint cycle.\nI have highlighted this as an issue to Senior management, showing the impact the additional work is having on morale and quality, but nothing seems to be changing at org level, and they seem to think the problem is at team level. Hence, keep trying to address the issue by adding and removing Scrum Masters; where they all are having the same problem.\nWhat is the best way to address this?\n\nA:\n\nTL;DR\nYou have both a political problem and a process problem. There is no silver bullet, and you can't fix these problems solely within the team. Instead, you must rely on the agile concepts of transparency and visibility to ensure that the costs associated with bypassing the process are charged to the project budget.\nMake Costs Visible\nIn agile frameworks, change is never non-negotiable. If you are faced with \"non-negotiable\" requirements that make the goal or plan for the current iteration obsolete, then the team must scrap the plan.\nThis is true regardless of the agile framework you're using. In Scrum, the Product Owner cancels the Sprint and the Scrum Team returns to Sprint Planning. In Kanban, work is pulled off the board to meet work-in-progress (WIP) limits, and queued in the appropriate bucket until capacity is once again available. Your specific framework may vary in the way it implements this freeing up of capacity, but it's an essential component of any viable agile (and even non-agile) process.\nThe cost of replanning and re-queuing is made visible through metrics such as cycle time, cumulative flow, velocity, burn-down, or other framework- and implementation-specific metrics. Communicating about these costs is also essential. By presenting the data to stakeholders and company officers, you place the responsibility for those costs squarely where it belongs.\nSchedule risk (like any risk) can be accepted, transferred, or mitigated. Which of these three options is best is a business decision, not a team decision. Regardless of the choice, senior management owns both the choice and the outcome.\nSee Also\nhttps://pm.stackexchange.com/a/10829/4271",
"output": "I agree with Todd about it being a political and process problem.\nI want to add that apart from showing them the impact on morale and quality -which I believe that they are quite important for your team- I would recommend having a discussion with them to understand why they believe it is useful to have this way of working and also prepare for them a \"business case\" that explores what is they key impact of this way of working for the senior management (trying to talk their own language basically). \nMost of the times, this comes down to how much money they are losing by having this way of working. Examples of the impact to senior management are:\n\nDelayed features due to the introduction of these changes \nBugs introduced due to quality issues that required to introduce more tasks in future releases.",
"source": "StackExchange",
"subset": "pm"
},
{
"input": "Use of LaTeX commands\n\nIt appears that LaTeX commands cannot be used on SE Linguistics, but they can on other SE sites. Personally, I wanted to use this feature a number of times to insert a formula and here's another question with the same problem. Expressing formulae can be useful when discussing things as simple as the number of combinations that arise from the different conditions in an experiment, or when discussing statistics questions relevant to linguistics. \nCould we please allow the use of LaTeX commands? I don't think having this feature will be a problem for anyone not wanting to use it and it will help those who do.\n\nA:\n\nAbsolutely subscribing here!\nI am surprised that this is not a feature already, I am really missing TeX support - for a wide range of uses that I consider essential in the field of linguistics: \n\nMathematical formulas.\nAs someone who frequently answers questions on formal semantics, like here, here and here, I feel that answering (and asking) these questions in the current state is a mess, to an extent where it has sometimes kept me from writing answers to those questions altogether, simply because the process of setting them up is so annoying.\nWhile I am aware that there are tools which ease the inclusion of Unicode symbols as HTML, all of these solutions (at least the ones I am aware of) still make typing more than one line of math quite cumbersome given that every symbol has to be clicked on individually, but worse than that, many characters and necessary formatting options, like pretty much half of the inventory I needed in this linguistically originated question, aren't even available as predefined Unicode symbols at all. Not to mention the fact that the output is visually not particularly pleasing and often makes the question hard to read, as can well be seen in this post.\nBesides, I doubt that the average user is aware of the existence of such tools; it took me more than a year on SE to find out about these, while new users have repeatedly been asking for LaTeX support, because this what users are used to.\nApart from formulas for questions on logic, MathJax would prove useful in questions involving statistics or combinatorics, which time and time again turn up this site. \nIPA, as already discussed here and here, to frequently be used in posts like here, here, here and many posts by other users.\nAt this point I am not sure whether enabling MathJax will also provide the opportunity to support other LaTeX packages, but if so, being able to use tipa would be of great help in typing IPA symbols.\nI am aware of IPA typeit which I have used so far, and some other external tools, but again, clicking is cumbersome, and users will be more comfortable writing questions and answers with the tools thy are used to (and the standard way of typesetting IPA symbols is LaTeX tipa). \nOther, less urgent cases where I could imagine TeX support to be useful - again under the premise that support for special packages is in principal possible - include\n\ntree structures, most importantly syntax trees, like here, here and here; a similar request was already made on meta here.\nlinguistic example formatting, which I could have used here, and interlinear glosses for posts like here, which are rather inconvenient to do appropriately without automatic alignment.\nThe current solution to these use cases is to typeset them externally and then include them in the post via a screenshot - I could live with this solution for these particular cases, but if MathJax comes, then this would certainly be a nice for these less frequent use cases as well. \n\nA:\n\nIt won't be a problem, but we need to show that the site actually needs it. \nOn Physics SE, the same question has been asked and you can see it got at least 32 votes. If there is not enough support from the community, which means enough people that want it, then it's probably not gonna happen. It's the same as for Migration paths in other sites, you need to prove there is a case for it to be implemented.\nIf it gets enough support, it'll be added. ",
"output": "A clear yes. And do not forget to include tipa support for IPA in LaTeX, I find tipa input from the keyboard very convenient (compared to searching, cutting, and pasting IPA characters from a Unicode character table).",
"source": "StackExchange",
"subset": "linguistics"
},
{
"input": "In this task, you are given text for US Congressional and California state bills, your task is to generate a summary for this bill.\n\"SECTION 1. SHORT TITLE.\n\n This Act may be cited as the ``Sequestration Relief Act of 2013''.\n\nSEC. 2. FINDINGS AND PURPOSE.\n\n (a) Findings.--Congress finds the following:\n (1) Congress must enact a comprehensive, deficit reduction \n plan to solve the country's fiscal challenges and to promote \n national security, economic stability, and the continued growth \n and prosperity of the United States.\n (2) The keys to a comprehensive, deficit reduction solution \n are increased revenues and changes in mandatory spending.\n (3) The Budget Control Act of 2011 was enacted to avert a \n default on Federal debt obligations, and it reduced \n discretionary spending by approximately $1 trillion through \n fiscal year 2021.\n (4) Because the Joint Select Committee on Deficit Reduction \n failed to recommend legislation providing an additional $1.2 \n trillion in deficit reduction, Federal law mandates that the \n additional savings be sequestered.\n (5) Sequestration was designed as a forcing mechanism for \n an agreement on a comprehensive, deficit reduction plan. It has \n failed to produce the intended results.\n (6) It no longer makes sense to rely on sequestration as a \n forcing mechanism for a balanced solution. The costs to our \n government and to the economy are too great.\n (7) Under sequestration, automatic, indiscriminate cuts \n would be applied, through fiscal year 2021, to a wide variety \n of discretionary spending programs to achieve $1.2 trillion in \n savings, forestalling the sound planning needed for prudent and \n meaningful investments in national security, the workforce, \n transportation infrastructure, education, health care, public \n safety, housing, innovation, small business development, and \n many other facets of enduring national strength.\n (8) Even the prospect of sequestration is disruptive to \n regular order and to the congressional appropriations process, \n and it fosters damaging economic uncertainty, while short-term \n solutions only suspend the prospect and continue to undermine \n the certainty needed for economic recovery.\n (9) Therefore, Congress must eliminate the threat of \n sequestration.\n (10) Given the magnitude of the Federal deficit, it is \n likely that additional cuts to discretionary spending will be \n necessary for a comprehensive deficit reduction solution.\n (11) Congress must establish a manageable, long-term \n discretionary spending plan. An additional $320 billion in \n targetable cuts to discretionary appropriations from fiscal \n year 2014 through fiscal year 2021 represents one-third of the \n net amount that would have been indiscriminately cut by \n sequestration over fiscal years 2013 through 2021.\n (12) It is recognized that a reduction of $167 billion to \n discretionary appropriations within budget function 050 from \n fiscal year 2014 through fiscal year 2021 will affect the \n National Military Strategy. The Department of Defense is highly \n encouraged to revisit its current strategic guidance and to \n work closely with Congress in building a new National Military \n Strategy that accounts for available resource levels.\n (b) Purposes.--The purposes of this Act are to--\n (1) eliminate the threat of sequestration to the American \n economy;\n (2) offer the Federal Government, industry, and the \n American people the predictability that economic recovery \n demands;\n (3) enable the Congress to pass appropriations legislation \n in regular order with a clear discretionary spending budget and \n grant the legislative and executive branches of government the \n flexibility needed to identify and implement specific \n discretionary spending reductions in a responsible and \n deliberate manner; and\n (4) provide a practicable, long-term discretionary spending \n plan that will contribute to a comprehensive, balanced, long-\n term, deficit reduction solution that includes affordable \n revisions to mandatory spending and new revenues.\n\nSEC. 3. REPEAL OF SECTION 251A SEQUESTRATIONS.\n\n Section 251A of the Balanced Budget and Emergency Deficit Control \nAct of 1985 is repealed.\n\nSEC. 4. $320 BILLION REDUCTION IN DISCRETIONARY SPENDING LIMITS.\n\n The discretionary spending limits set forth in paragraphs (3) \nthrough (10) of section 251(c) of the Balanced Budget and Emergency \nDeficit Control Act of 1985 are amended to read as follows:\n ``(3) for fiscal year 2014--\n ``(A) for the security category, $546,000,000,000 \n in budget authority; and\n ``(B) for the nonsecurity category, \n $501,000,000,000 in budget authority;\n ``(4) with respect to fiscal year 2015--\n ``(A) for the security category, $550,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $505,000,000,000 in new budget authority;\n ``(5) with respect to fiscal year 2016--\n ``(A) for the security category, $559,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $513,000,000,000 in new budget authority;\n ``(6) with respect to fiscal year 2017--\n ``(A) for the security category, $569,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $522,000,000,000 in new budget authority;\n ``(7) with respect to fiscal year 2018--\n ``(A) for the security category, $579,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $531,000,000,000 in new budget authority;\n ``(8) with respect to fiscal year 2019--\n ``(A) for the security category, $589,500,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $541,000,000,000 in new budget authority;\n ``(9) with respect to fiscal year 2020--\n ``(A) for the security category, $602,500,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $553,000,000,000 in new budget authority;\n ``(10) with respect to fiscal year 2021--\n ``(A) for the security category, $616,000,000,000 \n in new budget authority; and\n ``(B) for the nonsecurity category, \n $565,000,000,000 in new budget authority;''.\n\nSEC. 5. DEFINITION OF SECURITY CATEGORY.\n\n Section 250(c)(4)(B) of the Balanced Budget and Emergency Deficit \nControl Act of 1985 is amended to read as follows:\n ``(B)(i) For fiscal years 2012 and 2013, the term `security \n category' means discretionary appropriations associated with \n agency budgets for the Department of Defense, the Department of \n Homeland Security, the Department of Veterans Affairs, the \n National Nuclear Security Administration, the intelligence \n community management account (95-0401-0-1-054), and all budget \n accounts in budget function 150 (international affairs).\n ``(ii) For fiscal years 2014 through 2021, the term \n `security category' means discretionary appropriations in \n budget function 050 (national defense).''.\n\nSEC. 6. SUSPENSION OF STATUTORY LIMIT ON THE PUBLIC DEBT UNTIL FEBRUARY \n 1, 2017.\n\n Section 2 of the No Budget, No Pay Act of 2013 is amended--\n (1) in subsection (a), by striking ``May 18, 2013'' and \n inserting ``January 31, 2017''; and\n (2) in subsection (b), by striking ``May 19, 2013'' each \n place it appears and inserting ``February 1, 2017''.\"",
"output": "Sequestration Relief Act of 2013 - Amends the Balanced Budget and Emergency Deficit Controlled Act of 1985 (Gramm-Rudman-Hollings Act), as amended by the Budget Control Act of 2011, to repeal automatic sequestration to enforce a specified budget goal. Reduces the discretionary spending limits for security and nonsecurity categories by $320 billion for FY2014-FY2021. Confines the current meaning of "security category" under the Gramm-Rudman-Hollings Act to FY2012-FY2013. Revises the definition of "security category" for FY2014-FY2021 to mean discretionary appropriations in all of budget function 050 (national defense). Amends the No Budget, No Pay Act of 2013 to suspend through January 31, 2017, the current $16.394 trillion public debt limit. Postpones until February 1, 2017, an automatic increase in the public debt limit to the extent that: (1) the face amount of obligations issued and the face amount of obligations whose principal and interest are guaranteed by the federal government (except guaranteed obligations held by the Secretary of the Treasury) outstanding on February 1, 2017, exceeds (2) the face amount of such obligations outstanding on the date of enactment of the No Budget, No Pay Act of 2013 (February 24, 2013). Prohibits an obligation from being taken into account unless its issuance was necessary to fund a commitment incurred by the federal government that required payment before February 1, 2017.",
"source": "Natural Instructions",
"subset": "task1658 billsum summarization"
}
]
================================================
FILE: tests/data/test_alpaca.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from litgpt.data import Alpaca
from litgpt.prompts import Alpaca as AlpacaPromptStyle
def test_alpaca(mock_tokenizer, alpaca_path):
alpaca = Alpaca(val_split_fraction=0.5, download_dir=alpaca_path.parent, file_name=alpaca_path.name, num_workers=0)
assert isinstance(alpaca.prompt_style, AlpacaPromptStyle)
alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
alpaca.prepare_data()
alpaca.setup()
train_dataloader = alpaca.train_dataloader()
val_dataloader = alpaca.val_dataloader()
assert len(train_dataloader) == 6
assert len(val_dataloader) == 6
train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))
assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"}
for key in ["input_ids", "labels"]:
assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]"
assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]"
assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)
# has attributes from super class `LightningDataModule`
assert alpaca.prepare_data_per_node
================================================
FILE: tests/data/test_base.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from typing import Optional
import pytest
import torch
from litgpt.data.base import SFTDataset, get_sft_collate_fn
from litgpt.prompts import PromptStyle
@pytest.mark.parametrize("mask_prompt", [True, False])
@pytest.mark.parametrize("ignore_index", [-1, -100])
@pytest.mark.parametrize("max_seq_length", [1000, 5, -1])
def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):
class Style(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:
return f"In: {prompt} Out:"
i = ignore_index
data = [{"instruction": "Foo", "output": "Bar"}, {"instruction": "Boo", "output": "Ahh"}]
dataset = SFTDataset(
data=data,
tokenizer=mock_tokenizer,
prompt_style=Style(),
mask_prompt=mask_prompt,
ignore_index=ignore_index,
max_seq_length=max_seq_length,
)
assert len(dataset) == len(data)
expected_input_ids = torch.tensor([73, 110, 58, 32, 70, 111, 111, 32, 79, 117, 116, 58, 66, 97, 114, 1])
# If prompt is not masked, labels == input_ids
expected_labels = (
torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids
)
if max_seq_length == -1:
assert torch.equal(dataset[0]["input_ids"], expected_input_ids)
assert torch.equal(dataset[0]["labels"], expected_labels)
else:
assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
@pytest.mark.parametrize("ignore_index", [-1, -100])
@pytest.mark.parametrize("pad_id", [0, 100])
def test_sft_collate_fn_padding(pad_id, ignore_index):
collate = get_sft_collate_fn(pad_id=pad_id, ignore_index=ignore_index)
samples = [
{
"input_ids": torch.tensor([1, 2, 3]),
"labels": torch.tensor([10, 20, 30]),
"token_counts": {"raw": 3, "raw_plus_prompt_template": 25},
},
{
"input_ids": torch.tensor([4, 5, 6, 7, 8]),
"labels": torch.tensor([40, 50, 60, 70, 80]),
"token_counts": {"raw": 5, "raw_plus_prompt_template": 27},
},
]
expected = {
"input_ids": torch.tensor([[1, 2, 3, pad_id, pad_id], [4, 5, 6, 7, 8]]),
"labels": torch.tensor([[10, 20, 30, ignore_index, ignore_index], [40, 50, 60, 70, 80]]),
"token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])},
}
batch = collate(samples)
assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels"))
for key in ("raw", "raw_plus_prompt_template"):
assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}"
def test_sft_collate_fn_truncation():
collate = get_sft_collate_fn(max_seq_length=2)
samples = [
{
"input_ids": torch.tensor([1, 2, 3]),
"labels": torch.tensor([10, 20, 30]),
"token_counts": {"raw": 3, "raw_plus_prompt_template": 25},
},
{
"input_ids": torch.tensor([4, 5, 6, 7, 8]),
"labels": torch.tensor([40, 50, 60, 70, 80]),
"token_counts": {"raw": 5, "raw_plus_prompt_template": 27},
},
]
expected = {
"input_ids": torch.tensor([[1, 2], [4, 5]]),
"labels": torch.tensor([[10, 20], [40, 50]]),
"token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])},
}
batch = collate(samples)
assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels"))
for key in ("raw", "raw_plus_prompt_template"):
assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}"
================================================
FILE: tests/data/test_deita.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from unittest import mock
from litgpt.data import Deita, SFTDataset
from litgpt.data.deita import format_dataset
from litgpt.prompts import Alpaca as AlpacaPromptStyle
def test_format_dataset():
data = [
{
"prompt": "prompt1",
"prompt_id": "1",
"messages": [
{"content": "question1", "role": "user"},
{"content": "response1", "role": "assistant"},
{"content": "question2", "role": "user"},
{"content": "response2", "role": "assistant"},
],
},
{
"prompt": "prompt2",
"prompt_id": "2",
"messages": [
{"content": "question3", "role": "user"},
{"content": "response3", "role": "assistant"},
{"content": "question4", "role": "user"},
{"content": "response4", "role": "assistant"},
],
},
]
assert format_dataset(data, include_multi_turn_conversations=False) == [
{"instruction": "question1", "output": "response1", "input": ""},
{"instruction": "question3", "output": "response3", "input": ""},
]
assert format_dataset(data, include_multi_turn_conversations=True) == [
{"instruction": "question1", "output": "response1", "input": ""},
{"instruction": "question2", "output": "response2", "input": ""},
{"instruction": "question3", "output": "response3", "input": ""},
{"instruction": "question4", "output": "response4", "input": ""},
]
@mock.patch("litgpt.data.deita.format_dataset")
@mock.patch("datasets.load_dataset")
def test_deita(_, format_dataset_mock, mock_tokenizer, tmp_path):
format_dataset_mock.return_value = [
{"instruction": "inst1", "output": "out1"},
{"instruction": "inst2", "output": "out2"},
{"instruction": "inst3", "output": "out3"},
]
deita = Deita(num_workers=0, download_dir=tmp_path)
assert isinstance(deita.prompt_style, AlpacaPromptStyle)
deita.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
deita.prepare_data()
deita.setup()
train_dataloader = deita.train_dataloader()
assert isinstance(train_dataloader.dataset, SFTDataset)
assert len(train_dataloader) == 2
val_dataloader = deita.val_dataloader()
assert isinstance(val_dataloader.dataset, SFTDataset)
assert len(val_dataloader) == 2
assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)
# has attributes from super class `LightningDataModule`
assert deita.prepare_data_per_node
================================================
FILE: tests/data/test_json.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from typing import Optional
import pytest
from litgpt.data import JSON
from litgpt.prompts import PromptStyle
@pytest.mark.parametrize("as_jsonl", [False, True])
def test_json(as_jsonl, tmp_path, mock_tokenizer):
class Style(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:
return f"X: {prompt} {kwargs['input']} Y:"
json_path = tmp_path / ("data.jsonl" if as_jsonl else "data.json")
mock_data = [
{"instruction": "Add", "input": "2+2", "output": "4"},
{"instruction": "Subtract", "input": "5-3", "output": "2"},
{"instruction": "Multiply", "input": "6*4", "output": "24"},
{"instruction": "Divide", "input": "10/2", "output": "5"},
{"instruction": "Exponentiate", "input": "2^3", "output": "8"},
{"instruction": "Square root", "input": "√9", "output": "3"},
]
with open(json_path, "w", encoding="utf-8") as fp:
if as_jsonl:
for line in mock_data:
json.dump(line, fp)
fp.write("\n")
else:
json.dump(mock_data, fp)
data = JSON(json_path, val_split_fraction=0.5, prompt_style=Style(), num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
assert len(train_dataloader) == 2
assert len(val_dataloader) == 2
train_data = list(train_dataloader)
val_data = list(val_dataloader)
assert train_data[0]["input_ids"].size(0) == 2
assert train_data[1]["input_ids"].size(0) == 1
assert val_data[0]["input_ids"].size(0) == 2
assert val_data[1]["input_ids"].size(0) == 1
assert mock_tokenizer.decode(train_data[0]["input_ids"][0]).startswith("X: Divide 10/2 Y:5")
assert mock_tokenizer.decode(train_data[0]["input_ids"][1]).startswith("X: Add 2+2 Y:4")
assert mock_tokenizer.decode(train_data[1]["input_ids"][0]).startswith("X: Multiply 6*4 Y:24")
assert mock_tokenizer.decode(val_data[0]["input_ids"][0]).startswith("X: Exponentiate 2^3 Y:8")
assert mock_tokenizer.decode(val_data[0]["input_ids"][1]).startswith("X: Subtract 5-3 Y:2")
assert mock_tokenizer.decode(val_data[1]["input_ids"][0]).startswith("X: Square root √9 Y:3")
assert isinstance(train_dataloader.dataset.prompt_style, Style)
assert isinstance(val_dataloader.dataset.prompt_style, Style)
# has attributes from super class `LightningDataModule`
assert data.prepare_data_per_node
def test_json_input_validation(tmp_path):
with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"):
JSON(tmp_path / "not exist")
with pytest.raises(ValueError, match="`val_split_fraction` should not be set"):
JSON(tmp_path, val_split_fraction=0.5)
data = JSON(tmp_path)
data.prepare_data() # does nothing
# Empty directory
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()
# Only train.json exists
(tmp_path / "train.json").touch()
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()
# When a single file is passed without val_split_fraction, it defaults to 0.05 and warns.
with pytest.warns(UserWarning, match="Defaulting to `val_split_fraction=0.05`"):
data = JSON(tmp_path / "train.json", val_split_fraction=None)
assert data.val_split_fraction == 0.05
@pytest.mark.parametrize("as_jsonl", [False, True])
def test_json_with_splits(as_jsonl, tmp_path, mock_tokenizer):
mock_train_data = [
{"instruction": "Add", "input": "2+2", "output": "4"},
{"instruction": "Subtract", "input": "5-3", "output": "2"},
{"instruction": "Exponentiate", "input": "2^3", "output": "8"},
]
mock_test_data = [
{"instruction": "Multiply", "input": "6*4", "output": "24"},
{"instruction": "Divide", "input": "10/2", "output": "5"},
]
train_file = tmp_path / ("train.jsonl" if as_jsonl else "train.json")
val_file = tmp_path / ("val.jsonl" if as_jsonl else "val.json")
with open(train_file, "w", encoding="utf-8") as fp:
if as_jsonl:
for line in mock_train_data:
json.dump(line, fp)
fp.write("\n")
else:
json.dump(mock_train_data, fp)
with open(val_file, "w", encoding="utf-8") as fp:
if as_jsonl:
for line in mock_test_data:
json.dump(line, fp)
fp.write("\n")
else:
json.dump(mock_test_data, fp)
data = JSON(tmp_path, num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
assert len(train_dataloader) == 2
assert len(val_dataloader) == 1
================================================
FILE: tests/data/test_lit_data.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
from unittest import mock
from unittest.mock import ANY
import pytest
from litgpt.data import LitData
@pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining")
@mock.patch("litgpt.data.lit_data.LitData._dataloader")
def test_input_dir_and_splits(dl_mock, tmp_path):
with pytest.raises(ValueError, match="If provided `split_names` must be a tuple of two strings"):
LitData(data_path=tmp_path, split_names=("train",))
# local dir, no splits
data = LitData(data_path=tmp_path)
data.train_dataloader()
dl_mock.assert_called_with(input_dir=str(tmp_path), train=True)
data.val_dataloader()
dl_mock.assert_called_with(input_dir=str(tmp_path), train=False)
# local dir, splits
data = LitData(data_path=tmp_path, split_names=("train", "val"))
data.train_dataloader()
dl_mock.assert_called_with(input_dir=str(tmp_path / "train"), train=True)
data.val_dataloader()
dl_mock.assert_called_with(input_dir=str(tmp_path / "val"), train=False)
# remote dir, splits
data = LitData(data_path="s3://mydataset/data", split_names=("train", "val"))
data.train_dataloader()
dl_mock.assert_called_with(input_dir="s3://mydataset/data/train", train=True)
data.val_dataloader()
dl_mock.assert_called_with(input_dir="s3://mydataset/data/val", train=False)
@pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining")
@mock.patch("litdata.streaming.StreamingDataset")
@mock.patch("litdata.streaming.StreamingDataLoader")
def test_dataset_args(streaming_dataloader_mock, streaming_dataset_mock, tmp_path):
data = LitData(data_path=tmp_path, seed=1000)
data.train_dataloader()
streaming_dataset_mock.assert_called_with(
input_dir=str(tmp_path),
item_loader=ANY,
shuffle=True,
seed=1000,
)
streaming_dataloader_mock.assert_called_with(
streaming_dataset_mock(),
batch_size=1,
pin_memory=True,
num_workers=8,
drop_last=True,
)
================================================
FILE: tests/data/test_longform.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from litgpt.data import LongForm
from litgpt.prompts import Longform as LongFormPromptStyle
def test_longform(mock_tokenizer, longform_path):
longform = LongForm(download_dir=longform_path, num_workers=0)
assert isinstance(longform.prompt_style, LongFormPromptStyle)
longform.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
longform.prepare_data()
longform.setup()
train_dataloader = longform.train_dataloader()
val_dataloader = longform.val_dataloader()
assert len(train_dataloader) == 9
assert len(val_dataloader) == 5
train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))
assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"}
for key in ["input_ids", "labels"]:
assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]"
assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]"
assert isinstance(train_dataloader.dataset.prompt_style, LongFormPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, LongFormPromptStyle)
# has attributes from super class `LightningDataModule`
assert longform.prepare_data_per_node
================================================
FILE: tests/data/test_openwebtext.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
from unittest import mock
from unittest.mock import ANY, call
import pytest
from litdata.streaming import StreamingDataLoader, StreamingDataset
from torch.utils.data import DataLoader
from litgpt.data import OpenWebText
@pytest.mark.skipif(sys.platform == "win32", reason="Not in the mood to add Windows support right now.")
@mock.patch("litdata.optimize")
@mock.patch("litdata.streaming.dataset.subsample_streaming_dataset", return_value=([], []))
@mock.patch("datasets.load_dataset")
def test_openwebtext(_, __, optimize_mock, tmp_path, mock_tokenizer):
data = OpenWebText(data_path=(tmp_path / "openwebtext"))
assert data.seq_length == 2048
assert data.batch_size == 1
data.connect(tokenizer=mock_tokenizer, batch_size=2, max_seq_length=1024)
assert data.seq_length == 1025
assert data.batch_size == 2
# Data does not exist, preprocess it
data.prepare_data()
optimize_mock.assert_has_calls(
[
call(
fn=ANY,
num_workers=ANY,
inputs=[],
output_dir=str(tmp_path / "openwebtext" / "train"),
chunk_bytes="200MB",
),
call(
fn=ANY,
num_workers=ANY,
inputs=[],
output_dir=str(tmp_path / "openwebtext" / "val"),
chunk_bytes="200MB",
),
]
)
optimize_mock.reset_mock()
# Data exists, already preprocessed
(tmp_path / "openwebtext" / "train").mkdir(parents=True)
(tmp_path / "openwebtext" / "val").mkdir(parents=True)
data.prepare_data()
optimize_mock.assert_not_called()
data.setup()
train_dataloader = data.train_dataloader()
assert isinstance(train_dataloader, StreamingDataLoader)
assert isinstance(train_dataloader.dataset, StreamingDataset)
val_dataloader = data.val_dataloader()
assert isinstance(val_dataloader, DataLoader)
assert isinstance(val_dataloader.dataset, StreamingDataset)
# has attributes from super class `LightningDataModule`
assert data.prepare_data_per_node
================================================
FILE: tests/data/test_textfiles.py
================================================
import json
import torch
from litdata import TokensLoader, optimize
from torch.utils._pytree import tree_map
from litgpt.data.text_files import TextFiles
class Tokenizer:
bos_id = 0
def encode(self, text, bos, eos):
assert bos
assert not eos
return [self.bos_id] + [ord(c) for c in text]
def tokenize(data):
for story in data:
yield torch.tensor(story)
def fake_chunk(path, data):
optimize(
fn=tokenize,
inputs=[data] * len(data),
output_dir=str(path),
num_workers=1,
chunk_bytes="200MB",
item_loader=TokensLoader(),
)
def test_textfiles_datamodule(tmp_path):
from litgpt.data.text_files import TextFiles
data_dir = tmp_path / "textfiles"
datamodule = TextFiles(train_data_path=data_dir, num_workers=1)
datamodule.connect(max_seq_length=2, tokenizer=Tokenizer())
# simulate `datamodule.prepare_data`
train_data_dir = data_dir / "train"
train_data_dir.mkdir(parents=True)
fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])
datamodule.setup()
tr_dataloader = datamodule.train_dataloader()
tr_dataloader.shuffle = False
actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))
# there is 1 sample per index in the data (13)
assert actual == [
[[73, 5, 0]],
[[12, 0, 23]],
[[5, 0, 1]],
[[0, 73, 5]],
[[1999, 0, 13]],
[[0, 1, 1999]],
[[1, 1999, 0]],
[[0, 23, 15]],
[[13, 12, 0]],
[[63, 0, 73]],
[[23, 15, 63]],
[[15, 63, 0]],
[[0, 13, 12]],
]
class MockTokenizer:
bos_id = 0
eos_id = 1
use_bos = True
def encode(self, text, bos=True, eos=False, device=None, max_length=-1):
# Simple: map each character to its ordinal + 2
tokens = [ord(c) + 2 for c in text]
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens.append(self.eos_id)
if max_length > 0:
tokens = tokens[:max_length]
return torch.tensor(tokens, dtype=torch.long, device=device)
def decode(self, tensor):
ids = tensor.tolist() if tensor.ndim > 0 else [tensor.item()]
chars = []
for tid in ids:
if tid == self.bos_id:
chars.append("")
elif tid == self.eos_id:
chars.append("")
else:
chars.append(chr(tid - 2))
return "".join(chars)
def decode_stream(self, token_stream, device=None):
for token in token_stream:
yield self.decode(token)
@property
def vocab_size(self):
return 130
def test_textfiles_token_loader(tmp_path):
# Create the directory for text files
data_dir = tmp_path / "textfiles"
data_dir.mkdir(parents=True, exist_ok=True)
# Write sample training data to the directory
sample_texts = ["hello world", "foo bar", "lorem ipsum"]
for i, text in enumerate(sample_texts):
(data_dir / f"{i}.txt").write_text(text)
datamodule = TextFiles(train_data_path=data_dir, num_workers=1)
datamodule.connect(max_seq_length=2, tokenizer=MockTokenizer())
datamodule.prepare_data()
# ensure training set uses tokens loader
index_json = data_dir / "train" / "index.json"
assert index_json.exists()
meta = json.loads(index_json.read_text())
assert meta["config"]["item_loader"] == "TokensLoader"
# ensure validation set uses tokens loader
index_json = data_dir / "val" / "index.json"
assert index_json.exists()
meta = json.loads(index_json.read_text())
assert meta["config"]["item_loader"] == "TokensLoader"
================================================
FILE: tests/data/test_tinyllama.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from unittest import mock
import pytest
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from torch.utils.data import DataLoader
from litgpt.data import TinyLlama
@mock.patch("litdata.streaming.dataset.subsample_streaming_dataset", return_value=([], []))
def test_tinyllama(_, tmp_path):
data = TinyLlama(data_path=(tmp_path / "data"))
assert data.seq_length == 2048
assert data.batch_size == 1
data.connect(batch_size=2, max_seq_length=1024)
assert data.seq_length == 1025
assert data.batch_size == 2
with pytest.raises(FileNotFoundError, match="The directory .*data/slimpajama/train does not exist"):
data.prepare_data()
(tmp_path / "data" / "slimpajama" / "train").mkdir(parents=True)
(tmp_path / "data" / "slimpajama" / "val").mkdir(parents=True)
(tmp_path / "data" / "starcoder").mkdir(parents=True)
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
assert isinstance(train_dataloader, StreamingDataLoader)
assert isinstance(train_dataloader.dataset, CombinedStreamingDataset)
val_dataloader = data.val_dataloader()
assert isinstance(val_dataloader, DataLoader)
assert isinstance(val_dataloader.dataset, StreamingDataset)
# has attributes from super class `LightningDataModule`
assert data.prepare_data_per_node
================================================
FILE: tests/data/test_tinystories.py
================================================
import json
import pytest
import torch
from litdata import optimize
from litdata.streaming import StreamingDataset, TokensLoader
from torch.utils._pytree import tree_map
def tokenize(data):
for story in data:
yield torch.tensor(story)
def fake_chunk(path, data):
optimize(
fn=tokenize,
inputs=[data] * len(data),
output_dir=str(path),
num_workers=1,
chunk_bytes="200MB",
item_loader=TokensLoader(),
)
@pytest.mark.parametrize(
("max_seq_len", "expected"),
[
(2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]),
(5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]),
(6, [[0, 23, 15, 63, 0, 73, 5]]),
(7, [[0, 23, 15, 63, 0, 73, 5, 0]]),
],
)
def test_pretok_dataset(tmp_path, max_seq_len, expected):
fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13]
assert len(fake_data) == 12
fake_chunk(tmp_path, [fake_data])
dataset = StreamingDataset(
input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False
)
actual = tree_map(torch.Tensor.tolist, list(dataset))
assert actual == expected
def test_tokenize(tmp_path, monkeypatch):
from litgpt.data.tinystories import tokenize
story1, story2 = "foo bar", " fun "
data = [{"story": story1}, {"story": story2}]
shard_path = tmp_path / "data.json"
with open(shard_path, "w", encoding="utf-8") as f:
json.dump(data, f)
class Tokenizer:
bos_id = 0
def encode(self, text, bos, eos):
assert bos
assert not eos
return [self.bos_id] + [ord(c) for c in text]
monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", "0")
monkeypatch.setenv("DATA_OPTIMIZER_NUM_WORKERS", "1")
data = tokenize(str(shard_path), Tokenizer())
assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]]
def test_tinystories_datamodule(tmp_path):
from litgpt.data.tinystories import TinyStories
data_dir = tmp_path / "tinystories"
datamodule = TinyStories(data_dir, seed=42, num_workers=1)
datamodule.connect(max_seq_length=2)
# simulate `datamodule.prepare_data`
train_data_dir = data_dir / "train"
train_data_dir.mkdir(parents=True)
fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])
datamodule.setup()
tr_dataloader = datamodule.train_dataloader()
tr_dataloader.shuffle = False
actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))
# there is 1 sample per index in the data (13)
assert actual == [
[[73, 5, 0]],
[[12, 0, 23]],
[[5, 0, 1]],
[[0, 73, 5]],
[[1999, 0, 13]],
[[0, 1, 1999]],
[[1, 1999, 0]],
[[0, 23, 15]],
[[13, 12, 0]],
[[63, 0, 73]],
[[23, 15, 63]],
[[15, 63, 0]],
[[0, 13, 12]],
]
================================================
FILE: tests/ext_thunder/__init__.py
================================================
import sys
from pathlib import Path
# support running without installing as a package, adding extensions to the Python path
wd = Path(__file__).parent.parent.parent.resolve()
if wd.is_dir():
sys.path.append(str(wd))
else:
import warnings
warnings.warn(f"Could not find extensions directory at {wd}")
================================================
FILE: tests/ext_thunder/test_thunder_distributed.py
================================================
import os
import sys
from pathlib import Path
from typing import Optional, Tuple, Union
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from litgpt.constants import _THUNDER_AVAILABLE
from litgpt.utils import _RunIf
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
if _THUNDER_AVAILABLE:
from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy
from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy
@_RunIf(thunder=True)
def test_thunder_strategy_ddp_input_parsing():
with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
ThunderDDPStrategy(jit=False, executors=("python",))
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("choice", ["ddp", "fsdp"])
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_no_backward_sync_thunder(choice):
if choice == "ddp":
strategy = ThunderDDPStrategy()
elif choice == "fsdp":
strategy = ThunderFSDPStrategy()
else:
raise ValueError(f"Invalid choice: {choice}")
fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
fabric.launch()
# account for sharding in the case of FSDP
out_features = 1 if "ddp" in choice else fabric.world_size
model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device)
x = torch.randn(1, 1, device=fabric.device)
model = fabric.setup(model)
# 6 iters, 3 grad accumulation iters
for i, enabled in enumerate((True, True, False, True, True, False), 1):
x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32)
with fabric.no_backward_sync(model, enabled):
y = model(x)
fabric.backward(y.sum())
if not enabled:
# Math for the first 3 iters
#
# DistributedDataParallel
# (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9
# ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^
# rank0 rank1 allreduce
#
# thunder.distributed.ddp
# ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9
# ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^
# rank0 rank1 allreduce1 rank0 rank1 allreduce2
assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape
assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)
assert not hasattr(model.weight, "_thunder_fsdp_unsharded_grad")
model.weight.grad = None
elif choice == "fsdp":
assert model.weight._thunder_fsdp_unsharded_grad.shape == (2, 1)
assert model.weight.grad is None
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("jit", (False, True))
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_jit_ddp_before_setup(jit):
import thunder
fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit))
fabric.launch()
x = torch.randn(1, 1, device=fabric.device)
model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
tmodel = thunder.jit(model)
fmodel = fabric.setup(tmodel)
fmodel(x)
assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python()
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_strategy_ddp_setup_already_traced():
import thunder
device = torch.device("cuda")
x = torch.randn(1, 1, device=device)
model = torch.nn.Linear(1, 2, bias=False, device=device)
strategy = ThunderDDPStrategy()
tmodel = thunder.jit(model)
tmodel(x)
with pytest.raises(RuntimeError, match="already called"):
strategy.setup_module(tmodel)
@_RunIf(thunder=True)
def test_thunder_strategy_fsdp_input_parsing():
from thunder.distributed import FSDPBucketingStrategy, FSDPType
strategy = ThunderFSDPStrategy(bucketing_strategy="BlOcK", executors=("python",), sharding_strategy="zero3")
assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK
assert strategy.sharding_strategy is FSDPType.ZERO3
with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
ThunderFSDPStrategy(jit=False, executors=("python",))
@_RunIf(thunder=True)
def test_save_checkpoint_invalid_settings_raise(tmp_path):
strategy = ThunderFSDPStrategy(state_dict_type="full")
with pytest.raises(TypeError, match="not supported"):
strategy.save_checkpoint(tmp_path, {}, storage_options=object())
with pytest.raises(IsADirectoryError, match="path exists"):
strategy.save_checkpoint(tmp_path, {})
model = torch.nn.Linear(1, 1)
with pytest.raises(ValueError, match="Could not find"):
strategy.save_checkpoint(tmp_path / "foo", {})
model.use_fsdp = True
with pytest.raises(ValueError, match="Found multiple"):
strategy.save_checkpoint(tmp_path / "foo", {"model1": model, "model2": model})
with pytest.raises(ValueError, match="at least a model"):
strategy.load_checkpoint(tmp_path / "foo", {})
with pytest.raises(ValueError, match="must be a single file"):
strategy.load_checkpoint(tmp_path, model)
optimizer = torch.optim.Adam(model.parameters())
with pytest.raises(NotImplementedError, match="not supported"):
strategy.load_checkpoint(tmp_path, optimizer)
with pytest.raises(ValueError, match="Found multiple"):
strategy.load_checkpoint(tmp_path / "foo", {"model1": model, "model2": model})
with pytest.raises(ValueError, match="Could not find"):
strategy.load_checkpoint(tmp_path / "foo", {"foo": 1})
class Submodule(torch.nn.Module):
def __init__(self, h: int):
super().__init__()
self.l = torch.nn.Linear(4, h * 2, bias=False)
def forward(self, x):
# defined just because preprocessing fails otherwise
...
class MyModel(torch.nn.Module):
def __init__(self, h: int):
super().__init__()
self.register_buffer("buf", torch.tensor(0))
self.l = torch.nn.Linear(2, h)
self.inner = Submodule(h)
def forward(self):
# defined just because preprocessing fails otherwise
...
def reset_parameters(self):
self.buf = torch.empty_like(self.buf)
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_materialize_meta_tensors():
strategy = ThunderFSDPStrategy()
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
with fabric.init_module(empty_init=True):
model = MyModel(2)
model = fabric.setup(model)
# all parameters were moved
assert len(list(model.parameters())) == 3
assert all(p.device.type == "cuda" for p in model.parameters())
# buffers were moved too
assert model.buf.device.type == "cuda"
class StatefulThing:
def state_dict(self):
return {"thing": 1}
def load_state_dict(self, state_dict):
assert state_dict == self.state_dict()
class TensorLike:
def __init__(self, device: Optional[Union[str, torch.device]] = None, shape: Optional[Tuple[int, ...]] = None):
self.device = torch.device(device) if device is not None else None
self.shape = torch.Size(shape) if shape is not None else None
def __eq__(self, other):
return (
isinstance(other, torch.Tensor)
and (self.device is None or other.device == self.device)
and (self.shape is None or other.shape == self.shape)
)
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_save_load_full_checkpoint(tmp_path):
strategy = ThunderFSDPStrategy(state_dict_type="full", broadcast_from=0)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = MyModel(4)
expected = model.state_dict()
# save a sharded model
model = fabric.setup(model)
state = {"model": model, "stateful": StatefulThing(), "primitive": 123}
checkpoint_path = tmp_path / "foo"
fabric.save(checkpoint_path, state)
# assert the file contents
if fabric.global_rank == 0:
checkpoint = torch.load(checkpoint_path)
# cpu_offload is enabled by default
assert checkpoint == {
"model": {
"buf": TensorLike("cpu", tuple()),
"inner.l.weight": TensorLike("cpu", (8, 4)),
"l.bias": TensorLike("cpu", (4,)),
"l.weight": TensorLike("cpu", (4, 2)),
},
"stateful": {"thing": 1},
"primitive": 123,
}
torch.testing.assert_close(checkpoint["model"], expected)
# load its weights into a different sharded model
model = MyModel(4)
model = fabric.setup(model)
state = {"model": model, "stateful": StatefulThing(), "primitive": 321}
fabric.load(checkpoint_path, state)
from thunder.distributed import _unshard_params
# unshard this model's parameters to compare with the original state dict before sharding
_unshard_params(model, model.process_group_for_ddp, True)
# we loaded rank 0's weights, so this would fail in the other ranks
if fabric.global_rank == 0:
actual = model.state_dict()
# `_unshard_params` doesn't offload buffers at the moment
assert actual["buf"].device.type == "cuda"
actual["buf"] = actual["buf"].to(device="cpu")
torch.testing.assert_close(actual, expected)
assert state["primitive"] == 123
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_load_full_checkpoint_only_model(tmp_path):
strategy = ThunderFSDPStrategy()
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
checkpoint_path = tmp_path / "foo"
checkpoint_path = fabric.broadcast(checkpoint_path)
if fabric.global_rank == 0:
model = MyModel(4)
expected = model.state_dict()
torch.save(expected, checkpoint_path)
fabric.barrier()
expected = torch.load(checkpoint_path)
# before sharding
model = MyModel(4)
fabric.load_raw(checkpoint_path, model)
torch.testing.assert_close(model.state_dict(), expected)
# after sharding
model = MyModel(4)
model = fabric.setup(model)
fabric.load_raw(checkpoint_path, model)
from thunder.distributed import _unshard_params
# unshard this model's parameters to compare with the original state dict before sharding
_unshard_params(model, model.process_group_for_ddp, True)
actual = model.state_dict()
# `_unshard_params` doesn't offload buffers at the moment
assert actual["buf"].device.type == "cuda"
actual["buf"] = actual["buf"].to(device="cpu")
torch.testing.assert_close(actual, expected)
def distributed_ckpt_to_regular(path):
"""From ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save``."""
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
if _TORCH_GREATER_EQUAL_2_3:
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
else:
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import TensorStorageMetadata
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def set_up_planner(self, state_dict, metadata, is_coordinator):
assert not state_dict
# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype)
if k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v
super().set_up_planner(state_dict, metadata, is_coordinator)
state_dict = {}
storage_reader = FileSystemReader(path)
_load_state_dict(state_dict, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True)
return state_dict
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_save_load_sharded_checkpoint(tmp_path):
strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = MyModel(4)
expected = model.state_dict()
# save a sharded model
model = fabric.setup(model)
state = {"model": model, "stateful": StatefulThing(), "primitive": 123}
fabric.save(tmp_path, state)
# assert the file contents
if fabric.global_rank == 0:
assert set(os.listdir(tmp_path)) == {"meta.pt", "__1_0.distcp", "__0_0.distcp", ".metadata"}
metadata = torch.load(tmp_path / "meta.pt")
assert metadata == {"stateful": {"thing": 1}, "primitive": 123}
checkpoint = distributed_ckpt_to_regular(tmp_path)
# cpu_offload is enabled by default
assert checkpoint == {
"model": {
"buf": TensorLike("cpu", tuple()),
"inner.l.weight": TensorLike("cpu", (8, 4)),
"l.bias": TensorLike("cpu", (4,)),
"l.weight": TensorLike("cpu", (4, 2)),
}
}
torch.testing.assert_close(checkpoint["model"], expected)
# load its weights into a different sharded model
model = MyModel(4)
model = fabric.setup(model)
state = {"model": model, "stateful": StatefulThing(), "primitive": 321}
fabric.load(tmp_path, state)
from thunder.distributed import _unshard_params
# unshard this model's parameters to compare with the original state dict before sharding
_unshard_params(model, model.process_group_for_ddp, True)
# we loaded rank 0's weights, so this would fail in the other ranks
if fabric.global_rank == 0:
actual = model.state_dict()
# `_unshard_params` doesn't offload buffers at the moment
assert actual["buf"].device.type == "cuda"
actual["buf"] = actual["buf"].to(device="cpu")
torch.testing.assert_close(actual, expected)
assert state["primitive"] == 123
@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("jit", (False, True))
@pytest.mark.xfail(TypeError, reason="temporally disabled until resolved with Thunder")
def test_jit_fsdp_before_setup(jit):
import thunder
fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy(jit=jit))
fabric.launch()
x = torch.randn(1, 1, device=fabric.device)
model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
tmodel = thunder.jit(model)
fmodel = fabric.setup(tmodel)
fmodel(x)
assert "all_gather" in thunder.last_traces(tmodel)[-1].python()
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_strategy_fsdp_setup_already_traced():
import thunder
device = torch.device("cuda")
x = torch.randn(1, 1, device=device)
model = torch.nn.Linear(1, 2, bias=False, device=device)
strategy = ThunderFSDPStrategy()
tmodel = thunder.jit(model)
tmodel(x)
with pytest.raises(RuntimeError, match="already called"):
strategy.setup_module(tmodel)
================================================
FILE: tests/ext_thunder/test_thunder_networks.py
================================================
"""Run thunder tests as part of LitGPT CI"""
from litgpt.constants import _THUNDER_AVAILABLE
if _THUNDER_AVAILABLE:
from thunder.tests.test_networks import * # noqa: F403
else:
print("Skipping test_thunder_networks.py (thunder not available)")
================================================
FILE: tests/ext_thunder/test_thunder_pretrain.py
================================================
import os
from contextlib import redirect_stdout
from io import StringIO
from unittest.mock import Mock
import torch
from torch.utils.data import DataLoader
from litgpt import Config
from litgpt.args import EvalArgs, TrainArgs
from litgpt.constants import _THUNDER_AVAILABLE
from litgpt.utils import _RunIf
if _THUNDER_AVAILABLE:
import extensions.thunder.pretrain as thunder_pretrain
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_pretrain_thunder(tmp_path, monkeypatch):
model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
monkeypatch.setattr(thunder_pretrain, "get_dataloaders", Mock(return_value=(dataloader, dataloader)))
monkeypatch.setattr(thunder_pretrain, "save_hyperparameters", Mock())
out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
thunder_pretrain.setup(
devices=1,
model_config=model_config,
out_dir=out_dir,
train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
eval=EvalArgs(interval=1, max_iters=1),
optimizer="AdamW",
)
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
# the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "model_config.yaml"}
assert (out_dir / "logs" / "tensorboard" / "version_0").is_dir()
logs = stdout.getvalue()
assert logs.count("(step)") == 4
assert logs.count("val loss") == 4
assert "Total parameters: 1,888" in logs
================================================
FILE: tests/ext_thunder/test_unsloth_executor.py
================================================
import pytest
import torch
from litgpt import GPT, Config
from litgpt.model import apply_rope, build_rope_cache
from litgpt.utils import _RunIf, chunked_cross_entropy
@_RunIf(min_cuda_gpus=1, thunder=True)
@pytest.mark.parametrize("reduction", ["none", "mean"])
def test_unsloth_cross_entropy(reduction):
import thunder
from extensions.thunder.unsloth.executor import unsloth_ex
logits = torch.randn(64, 128, device="cuda", requires_grad=True)
labels = torch.randint(128, (64,), device="cuda")
def foo(logits, labels):
# this is the variant supported by unsloth.
# if different arguments are used, the implementation would no be lowered to unsloth and instead would get
# decomposed
return torch.nn.functional.cross_entropy(logits, labels, reduction=reduction, ignore_index=-100)
cfoo = thunder.jit(foo, executors=[unsloth_ex])
actual = cfoo(logits, labels)
trace_str = str(thunder.last_traces(cfoo)[-1])
assert "unsloth_cross_entropy" in trace_str and "backward" not in trace_str
trace_str = str(thunder.last_backward_traces(cfoo)[-1])
assert "unsloth_cross_entropy_backward" in trace_str
expected = foo(logits, labels)
torch.testing.assert_close(actual, expected)
(actual_grad,) = torch.autograd.grad(actual.sum(), logits)
trace_str = str(thunder.last_backward_traces(cfoo)[-1])
assert "unsloth_cross_entropy_backward" in trace_str
out = foo(logits, labels)
assert logits.grad is None
(expected_grad,) = torch.autograd.grad(out.sum(), logits)
torch.testing.assert_close(actual_grad, expected_grad)
@pytest.mark.skip(reason="out of date")
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_unsloth_rope():
import thunder
from extensions.thunder.unsloth.executor import unsloth_ex
B, nh, T, hs = 2, 32, 64, 16
cos, sin = build_rope_cache(T, hs, device="cuda")
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
q = torch.rand((B, nh, T, hs), device="cuda", requires_grad=True)
def foo(x, cos, sin):
return apply_rope(x, cos, sin)
cfoo = thunder.jit(foo, executors=[unsloth_ex])
actual = cfoo(q, cos, sin)
trace_str = str(thunder.last_traces(cfoo)[-1])
assert "unsloth_apply_rope" in trace_str and "backward" not in trace_str
trace_str = str(thunder.last_backward_traces(cfoo)[-1])
assert "unsloth_apply_rope_backward" in trace_str
expected = foo(q, cos, sin)
torch.testing.assert_close(actual, expected)
(actual_grad,) = torch.autograd.grad(actual.sum(), q)
(expected_grad,) = torch.autograd.grad(expected.sum(), q)
torch.testing.assert_close(actual_grad, expected_grad)
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_unsloth_swiglu():
import thunder
from extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex
from litgpt import Config
from litgpt.model import LLaMAMLP
config = Config.from_name("Llama-2-7b-hf")
with torch.device("cuda"):
x = torch.randn(2, 16, config.n_embd, requires_grad=True)
mlp = LLaMAMLP(config)
# monkeypatching was successful
assert isinstance(mlp, ThunderLLaMAMLP)
cmlp = thunder.jit(mlp, executors=[unsloth_ex])
actual = cmlp(x)
trace_str = str(thunder.last_traces(cmlp)[-1])
assert "unsloth_swiglu" in trace_str and "backward" not in trace_str
trace_str = str(thunder.last_backward_traces(cmlp)[-1])
assert "unsloth_swiglu_backward" in trace_str
expected = mlp(x)
torch.testing.assert_close(actual, expected)
(actual_grad,) = torch.autograd.grad(actual.sum(), x)
(expected_grad,) = torch.autograd.grad(expected.sum(), x)
torch.testing.assert_close(actual_grad, expected_grad)
@_RunIf(min_cuda_gpus=1, thunder=True)
def test_unsloth_gpt():
import thunder
from extensions.thunder.unsloth.executor import unsloth_ex
def forward_and_loss(model, input_ids, targets):
logits = model(input_ids)
return chunked_cross_entropy(logits, targets, chunk_size=0)
cfn = thunder.jit(forward_and_loss, executors=[unsloth_ex])
device = torch.device("cuda")
config = Config(
vocab_size=320,
padding_multiple=64,
n_layer=2,
n_head=4,
n_embd=64,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=1376,
)
with device:
model = GPT(config)
input_ids = torch.randint(1, 10, (2, 3))
targets = torch.randint(0, 10, (2, 3))
loss = cfn(model, input_ids, targets)
assert isinstance(loss, torch.Tensor)
fwd = thunder.last_traces(cfn)
bwd = thunder.last_backward_traces(cfn)
fwd_str, bwd_str = fwd[-1].python(), bwd[-1].python()
assert "unsloth_cross_entropy" in fwd_str
assert "unsloth_cross_entropy_backward" in bwd_str
assert "unsloth_apply_rope" in fwd_str
assert "unsloth_apply_rope_backward" in bwd_str
assert "unsloth_swiglu" in fwd_str
assert "unsloth_swiglu_backward" in bwd_str
================================================
FILE: tests/generate/__init__.py
================================================
================================================
FILE: tests/generate/test_adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import re
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from unittest.mock import ANY, Mock, call
import pytest
import torch
import yaml
skip_in_ci_on_macos = pytest.mark.skipif(
sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true",
reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.",
)
@skip_in_ci_on_macos
@pytest.mark.parametrize("version", ("v1", "v2"))
def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):
if version == "v1":
import litgpt.generate.adapter as generate
else:
import litgpt.generate.adapter_v2 as generate
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
config_path.write_text(yaml.dump(config))
monkeypatch.setattr(generate, "lazy_load", Mock())
monkeypatch.setattr(generate.GPT, "load_state_dict", Mock())
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([[3, 2, 1]])
monkeypatch.setattr(generate, "generate", generate_mock)
num_samples = 1
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(temperature=2.0, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)
assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
assert (
generate_mock.mock_calls
== [call(ANY, tensor_like, 101, temperature=2.0, top_k=2, top_p=0.9, eos_id=ANY)] * num_samples
)
expected_output = "foo bar baz\n" * num_samples
# Allow for the config to be printed before the expected repeated strings.
pattern = rf".*^{re.escape(expected_output.strip())}$.*"
assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)
err_value = err.getvalue()
expected_parts = [
"'padded_vocab_size': 512",
"'n_layer': 2",
"'n_head': 4",
"'head_size': 2",
"'n_embd': 8",
]
assert all(part in err_value for part in expected_parts)
@pytest.mark.parametrize("version", ("", "_v2"))
def test_cli(version):
args = ["litgpt", f"generate_adapter{version}", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "For models finetuned with" in output
================================================
FILE: tests/generate/test_main.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import re
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import ANY, Mock, call
import pytest
import torch
import yaml
import litgpt.generate.base as generate
from litgpt import GPT, Config
from litgpt.generate.base import sample
skip_in_ci_on_macos = pytest.mark.skipif(
sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true",
reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.",
)
@pytest.mark.parametrize(
"max_seq_length", (pytest.param(10, marks=pytest.mark.xfail(raises=NotImplementedError, strict=True)), 20 + 5)
)
def test_generate(max_seq_length):
import lightning as L
L.seed_everything(1234)
T = 5
input_idx = torch.arange(0, T)
config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
model = GPT(config)
model.max_seq_length = max_seq_length
model.set_kv_cache(batch_size=1)
max_new_tokens = 20
multinomial_results = []
def multinomial(*args, **kwargs):
out = torch.multinomial(*args, **kwargs, num_samples=1)
multinomial_results.append(out)
return out
with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial):
out = generate.generate(model, input_idx, T + max_new_tokens, top_k=1)
assert out.size(0) == T + max_new_tokens, (out.size(0), T + max_new_tokens)
multinomial_results = torch.hstack(multinomial_results)
expected = torch.cat((input_idx, multinomial_results))
assert out.shape == expected.shape, (out.shape, expected.shape)
torch.testing.assert_close(out, expected)
@skip_in_ci_on_macos
def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
config_path.write_text(yaml.dump(config))
module_mock = Mock()
module_mock.config.block_size = 128
load_mock = Mock()
load_mock.return_value = load_mock
monkeypatch.setattr(generate, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([3, 2, 1])
monkeypatch.setattr(generate, "generate", generate_mock)
num_samples = 2
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)
assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
assert (
generate_mock.mock_calls
== [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)]
* num_samples
)
expected_output = "foo bar baz\n" * num_samples
# Allow for the config to be printed before the expected repeated strings.
pattern = rf".*^{re.escape(expected_output.strip())}$.*"
assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)
err_value = err.getvalue()
expected_parts = [
"'padded_vocab_size': 512",
"'n_layer': 2",
"'n_head': 4",
]
assert all(part in err_value for part in expected_parts)
def test_cli():
args = ["litgpt", "generate", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Default generation option" in output
@pytest.mark.parametrize("temperature", (0.0, 1.0, 0.5))
def test_sample(temperature):
# shape: 2x3x5
logits = torch.tensor(
[
[[24, 4, 98, 77, 47], [65, 70, 32, 67, 24], [92, 32, 88, 36, 62]],
[[85, 79, 57, 68, 50], [89, 46, 72, 45, 32], [68, 96, 68, 24, 36]],
],
dtype=torch.float32,
)
token = sample(logits, temperature=temperature, top_p=0.8)
assert token.shape == (1,)
# sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported
assert token.tolist() == [0]
def test_generate_different_results_with_different_top_p():
config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
model = GPT(config)
model.max_seq_length = 50
model.set_kv_cache(batch_size=1)
torch.manual_seed(123)
input_idx = torch.randint(10, size=(1,))
torch.manual_seed(123)
output1 = generate.generate(model, input_idx, 20, top_p=1.0)
torch.manual_seed(123)
output2 = generate.generate(model, input_idx, 20, top_p=0.1)
assert not torch.equal(output1, output2)
================================================
FILE: tests/generate/test_sequentially.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import itertools
import subprocess
import sys
from dataclasses import asdict
from pathlib import Path
from re import escape
import pytest
import torch
import yaml
from lightning import Fabric
from litgpt import Config
from litgpt.generate.sequentially import (
chunk_sizes,
layer_to_device,
replace_device,
sequential,
)
from litgpt.model import GPT, Block
from litgpt.scripts.download import download_from_hub
from litgpt.utils import _RunIf
from .utils import find_forward_hooks
@pytest.mark.parametrize(
("n_layer", "devices", "expected"),
[
(6, 1, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}),
(6, 2, {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}),
(6, 3, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}),
(6, 4, {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 3}),
(6, 5, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}),
(6, 6, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}),
],
)
def test_layer_to_device(n_layer, devices, expected):
with torch.device("meta"):
model = GPT.from_name("pythia-14m", n_layer=n_layer)
c_sizes = chunk_sizes(n_layer, devices)
actual = layer_to_device(model, Block, chunk_sizes=c_sizes)
expected = {f"transformer.h.{i}": v for i, v in expected.items()}
assert actual == expected
def path_to_device(model):
return {k: str(v.device) for k, v in itertools.chain(model.named_parameters(), model.named_buffers())}
def test_replace_device():
class Submodule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("foo", torch.tensor(1, device="cpu"))
self.register_buffer("bar", torch.tensor(1, device="cpu"))
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.modules = torch.nn.ModuleDict(
{
"module1": torch.nn.Linear(1, 1, bias=True, device="meta"),
"module2": torch.nn.Linear(1, 1, bias=False, device="cpu"),
}
)
self.submodule = Submodule()
model = MyModel()
assert path_to_device(model) == {
"modules.module1.bias": "meta",
"modules.module1.weight": "meta",
"modules.module2.weight": "cpu",
"submodule.bar": "cpu",
"submodule.foo": "cpu",
}
model = replace_device(model, torch.device("cpu"), torch.device("meta"))
assert path_to_device(model) == {
"modules.module1.bias": "meta",
"modules.module1.weight": "meta",
"modules.module2.weight": "meta",
"submodule.bar": "meta",
"submodule.foo": "meta",
}
model = MyModel()
model.submodule.bar = model.submodule.bar.to("meta")
with pytest.raises(
ValueError,
match=escape("multiple devices: {'submodule.foo': device(type='cpu'), 'submodule.bar': device(type='meta')}"),
):
replace_device(model, torch.device("cpu"), torch.device("meta"))
def _test_model_1device(accelerator):
fabric = Fabric(accelerator=accelerator, devices=1)
with torch.device("meta"):
model = GPT.from_name("pythia-14m", n_layer=2)
model = sequential(model, fabric.device, 15, 1)
device_str = str(fabric.device)
assert path_to_device(model) == {
"cos": device_str,
"sin": device_str,
"lm_head.weight": device_str,
"transformer.h.0.attn.qkv.bias": device_str,
"transformer.h.0.attn.qkv.weight": device_str,
"transformer.h.0.attn.proj.bias": device_str,
"transformer.h.0.attn.proj.weight": device_str,
"transformer.h.0.mlp.fc.bias": device_str,
"transformer.h.0.mlp.fc.weight": device_str,
"transformer.h.0.mlp.proj.bias": device_str,
"transformer.h.0.mlp.proj.weight": device_str,
"transformer.h.0.norm_1.bias": device_str,
"transformer.h.0.norm_1.weight": device_str,
"transformer.h.0.norm_2.bias": device_str,
"transformer.h.0.norm_2.weight": device_str,
"transformer.h.0.attn.kv_cache.k": device_str,
"transformer.h.0.attn.kv_cache.v": device_str,
"transformer.h.1.attn.qkv.bias": device_str,
"transformer.h.1.attn.qkv.weight": device_str,
"transformer.h.1.attn.proj.bias": device_str,
"transformer.h.1.attn.proj.weight": device_str,
"transformer.h.1.mlp.fc.bias": device_str,
"transformer.h.1.mlp.fc.weight": device_str,
"transformer.h.1.mlp.proj.bias": device_str,
"transformer.h.1.mlp.proj.weight": device_str,
"transformer.h.1.norm_1.bias": device_str,
"transformer.h.1.norm_1.weight": device_str,
"transformer.h.1.norm_2.bias": device_str,
"transformer.h.1.norm_2.weight": device_str,
"transformer.h.1.attn.kv_cache.k": device_str,
"transformer.h.1.attn.kv_cache.v": device_str,
"transformer.ln_f.bias": device_str,
"transformer.ln_f.weight": device_str,
"transformer.wte.weight": device_str,
}
assert model.max_seq_length == 15
@_RunIf(min_cuda_gpus=1)
def test_model_1device_cuda():
_test_model_1device("cuda")
def test_model_1device_cpu():
_test_model_1device("cpu")
@_RunIf(min_cuda_gpus=2)
def test_model_forward_hooks():
fabric = Fabric(accelerator="cuda", devices=1)
with torch.device("meta"):
model = GPT.from_name("pythia-14m") # 6 layers
model = sequential(model, fabric.device, max_seq_length=15, devices=2)
hooks = find_forward_hooks(model)
actual = path_to_device(model)
assert actual == {
"lm_head.weight": "cuda:0",
"transformer.wte.weight": "cuda:0",
"transformer.h.0.norm_1.weight": "cuda:0",
"transformer.h.0.norm_1.bias": "cuda:0",
"transformer.h.0.attn.qkv.weight": "cuda:0",
"transformer.h.0.attn.qkv.bias": "cuda:0",
"transformer.h.0.attn.proj.weight": "cuda:0",
"transformer.h.0.attn.proj.bias": "cuda:0",
"transformer.h.0.norm_2.weight": "cuda:0",
"transformer.h.0.norm_2.bias": "cuda:0",
"transformer.h.0.mlp.fc.weight": "cuda:0",
"transformer.h.0.mlp.fc.bias": "cuda:0",
"transformer.h.0.mlp.proj.weight": "cuda:0",
"transformer.h.0.mlp.proj.bias": "cuda:0",
"transformer.h.1.norm_1.weight": "cuda:0",
"transformer.h.1.norm_1.bias": "cuda:0",
"transformer.h.1.attn.qkv.weight": "cuda:0",
"transformer.h.1.attn.qkv.bias": "cuda:0",
"transformer.h.1.attn.proj.weight": "cuda:0",
"transformer.h.1.attn.proj.bias": "cuda:0",
"transformer.h.1.norm_2.weight": "cuda:0",
"transformer.h.1.norm_2.bias": "cuda:0",
"transformer.h.1.mlp.fc.weight": "cuda:0",
"transformer.h.1.mlp.fc.bias": "cuda:0",
"transformer.h.1.mlp.proj.weight": "cuda:0",
"transformer.h.1.mlp.proj.bias": "cuda:0",
"transformer.h.2.norm_1.weight": "cuda:0",
"transformer.h.2.norm_1.bias": "cuda:0",
"transformer.h.2.attn.qkv.weight": "cuda:0",
"transformer.h.2.attn.qkv.bias": "cuda:0",
"transformer.h.2.attn.proj.weight": "cuda:0",
"transformer.h.2.attn.proj.bias": "cuda:0",
"transformer.h.2.norm_2.weight": "cuda:0",
"transformer.h.2.norm_2.bias": "cuda:0",
"transformer.h.2.mlp.fc.weight": "cuda:0",
"transformer.h.2.mlp.fc.bias": "cuda:0",
"transformer.h.2.mlp.proj.weight": "cuda:0",
"transformer.h.2.mlp.proj.bias": "cuda:0",
"transformer.h.3.norm_1.weight": "cuda:1",
"transformer.h.3.norm_1.bias": "cuda:1",
"transformer.h.3.attn.qkv.weight": "cuda:1",
"transformer.h.3.attn.qkv.bias": "cuda:1",
"transformer.h.3.attn.proj.weight": "cuda:1",
"transformer.h.3.attn.proj.bias": "cuda:1",
"transformer.h.3.norm_2.weight": "cuda:1",
"transformer.h.3.norm_2.bias": "cuda:1",
"transformer.h.3.mlp.fc.weight": "cuda:1",
"transformer.h.3.mlp.fc.bias": "cuda:1",
"transformer.h.3.mlp.proj.weight": "cuda:1",
"transformer.h.3.mlp.proj.bias": "cuda:1",
"transformer.h.4.norm_1.weight": "cuda:1",
"transformer.h.4.norm_1.bias": "cuda:1",
"transformer.h.4.attn.qkv.weight": "cuda:1",
"transformer.h.4.attn.qkv.bias": "cuda:1",
"transformer.h.4.attn.proj.weight": "cuda:1",
"transformer.h.4.attn.proj.bias": "cuda:1",
"transformer.h.4.norm_2.weight": "cuda:1",
"transformer.h.4.norm_2.bias": "cuda:1",
"transformer.h.4.mlp.fc.weight": "cuda:1",
"transformer.h.4.mlp.fc.bias": "cuda:1",
"transformer.h.4.mlp.proj.weight": "cuda:1",
"transformer.h.4.mlp.proj.bias": "cuda:1",
"transformer.h.5.norm_1.weight": "cuda:1",
"transformer.h.5.norm_1.bias": "cuda:1",
"transformer.h.5.attn.qkv.weight": "cuda:1",
"transformer.h.5.attn.qkv.bias": "cuda:1",
"transformer.h.5.attn.proj.weight": "cuda:1",
"transformer.h.5.attn.proj.bias": "cuda:1",
"transformer.h.5.norm_2.weight": "cuda:1",
"transformer.h.5.norm_2.bias": "cuda:1",
"transformer.h.5.mlp.fc.weight": "cuda:1",
"transformer.h.5.mlp.fc.bias": "cuda:1",
"transformer.h.5.mlp.proj.weight": "cuda:1",
"transformer.h.5.mlp.proj.bias": "cuda:1",
"transformer.ln_f.weight": "cuda:0",
"transformer.ln_f.bias": "cuda:0",
"cos": "cuda:0",
"sin": "cuda:0",
"transformer.h.0.attn.kv_cache.k": "cuda:0",
"transformer.h.0.attn.kv_cache.v": "cuda:0",
"transformer.h.1.attn.kv_cache.k": "cuda:0",
"transformer.h.1.attn.kv_cache.v": "cuda:0",
"transformer.h.2.attn.kv_cache.k": "cuda:0",
"transformer.h.2.attn.kv_cache.v": "cuda:0",
"transformer.h.3.attn.kv_cache.k": "cuda:1",
"transformer.h.3.attn.kv_cache.v": "cuda:1",
"transformer.h.4.attn.kv_cache.k": "cuda:1",
"transformer.h.4.attn.kv_cache.v": "cuda:1",
"transformer.h.5.attn.kv_cache.k": "cuda:1",
"transformer.h.5.attn.kv_cache.v": "cuda:1",
}
assert hooks == {
"transformer.h.3": [("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {})],
"transformer.h.4": [("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {})],
"transformer.h.5": [
("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {}),
("forward_hook", "move_block_output", (torch.device(type="cuda", index=0),), {}),
],
}
root = Path(__file__).parent.parent.resolve()
@_RunIf(min_cuda_gpus=2)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
def test_base_with_sequentially(tmp_path):
# download the tokenizer
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
checkpoint_dir = tmp_path / "EleutherAI/pythia-14m"
# save the config
config = Config.from_name("pythia-14m")
(checkpoint_dir / "model_config.yaml").write_text(yaml.dump(asdict(config)))
# create a state dict to load from
torch.save(GPT(config).state_dict(), checkpoint_dir / "lit_model.pth")
args = [
str(checkpoint_dir),
"--num_samples=1",
"--max_new_tokens=10",
"--precision=16-true",
"--temperature=0.0",
]
env = {"CUDA_VISIBLE_DEVICES": "0,1"}
sequential_stdout = subprocess.check_output(
[sys.executable, "-m", "litgpt", "generate_sequentially", *args],
env=env,
cwd=root,
).decode()
assert "What food do llamas eat?" in sequential_stdout
def test_cli():
args = ["litgpt", "generate_sequentially", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generation script that partitions layers across" in output
================================================
FILE: tests/generate/test_tp.py
================================================
import subprocess
import sys
from dataclasses import asdict, replace
from pathlib import Path
from unittest.mock import Mock
import pytest
import torch
import yaml
from litgpt import GPT, Config
from litgpt.generate.tp import tensor_parallel, tensor_parallel_linear
from litgpt.scripts.download import download_from_hub
from litgpt.utils import _RunIf
from .utils import find_forward_hooks
def test_tensor_parallel_linear():
fabric = Mock()
fabric.world_size = 4
fabric.global_rank = 2
def get_linear(bias=True):
linear = torch.nn.Linear(8, 8, bias=bias)
linear.weight.data = torch.arange(64, dtype=torch.float32).reshape(8, 8)
if bias:
linear.bias.data = torch.arange(8, dtype=torch.float32)
return linear
linear = get_linear()
tensor_parallel_linear(fabric, linear, "colwise")
expected = torch.arange(32, 48, dtype=torch.float32).reshape(2, 8)
torch.testing.assert_close(linear.weight, expected)
expected = torch.arange(4, 6, dtype=torch.float32)
torch.testing.assert_close(linear.bias, expected)
linear = get_linear(bias=False)
tensor_parallel_linear(fabric, linear, "rowwise")
expected = torch.arange(4, 62, 8, dtype=torch.float32).reshape(8, 1)
expected = torch.cat([expected, expected + 1], dim=1)
torch.testing.assert_close(linear.weight, expected)
assert linear.bias is None
@pytest.mark.parametrize(
("name", "expected"),
[
(
"Llama-2-70b-hf",
{
"transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.0.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
},
),
(
"falcon-180B",
{
"transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.0.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.mlp": [("forward_hook", "all_reduce_output", (8,), {})],
},
),
(
"Mixtral-8x7B-v0.1",
{
"transformer.h.0.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.0.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.0.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.1.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.attn": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.mlp.experts.0": [("forward_hook", "all_reduce_output", (8,), {})],
"transformer.h.2.mlp.experts.1": [("forward_hook", "all_reduce_output", (8,), {})],
},
),
],
)
def test_tensor_parallel_llama(name, expected):
fabric = Mock()
fabric.world_size = 8
fabric.global_rank = 1
with torch.device("meta"):
model = GPT.from_name(name, n_layer=3, n_expert=2)
config = replace(model.config) # make a copy
model = tensor_parallel(fabric, model)
hooks = find_forward_hooks(model)
assert hooks == expected
assert model.config.n_embd * 8 == config.n_embd
assert model.config.n_head * 8 == config.n_head
assert model.config.n_query_groups * 8 == config.n_query_groups
root = Path(__file__).parent.parent.resolve()
@_RunIf(min_cuda_gpus=2)
def test_tp(tmp_path):
# download the tokenizer
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
checkpoint_dir = tmp_path / "EleutherAI/pythia-14m"
# save the config
config = Config.from_name("pythia-14m")
(checkpoint_dir / "model_config.yaml").write_text(yaml.dump(asdict(config)))
# create a state dict to load from
torch.save(GPT(config).state_dict(), checkpoint_dir / "lit_model.pth")
args = [
str(checkpoint_dir),
"--num_samples=1",
"--max_new_tokens=10",
"--precision=16-true",
"--temperature=0.0",
]
env = {"CUDA_VISIBLE_DEVICES": "0,1"}
tp_stdout = subprocess.check_output(
[sys.executable, "-m", "litgpt", "generate_tp", *args], env=env, cwd=root
).decode()
# there is some unaccounted randomness so cannot compare the output with that of `generate/base.py`
assert "What food do llamas eat?" in tp_stdout
def test_cli():
args = ["litgpt", "generate_tp", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Generation script that uses tensor parallelism" in output
================================================
FILE: tests/generate/utils.py
================================================
from collections import defaultdict
def find_forward_hooks(module):
mapping = defaultdict(list)
for name, submodule in module.named_modules():
for hook in submodule._forward_pre_hooks.values():
hook_data = ("forward_pre_hook", hook.func.__name__, hook.args, hook.keywords)
mapping[name].append(hook_data)
for hook in submodule._forward_hooks.values():
hook_data = ("forward_hook", hook.func.__name__, hook.args, hook.keywords)
mapping[name].append(hook_data)
return dict(mapping)
================================================
FILE: tests/test_adapter.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stdout
from copy import deepcopy
from dataclasses import asdict
from io import StringIO
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
import yaml
from lightning import Fabric
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.wrappers import _FabricOptimizer
from torch._dynamo.backends import debugging
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
import litgpt.adapter as gpt_adapter
import litgpt.finetune.adapter as module
import litgpt.model as gpt
from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca
from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
from litgpt.utils import _RunIf
def test_config_identical():
name = "pythia-14m"
base_config = asdict(gpt.Config.from_name(name))
adapter_config = asdict(gpt_adapter.Config.from_name(name))
del adapter_config["adapter_prompt_length"]
del adapter_config["adapter_start_layer"]
assert adapter_config == base_config
with Fabric(accelerator="cpu").init_module(empty_init=True):
base_model = gpt.GPT.from_name(name)
adapter_model = gpt_adapter.GPT.from_name(name)
assert adapter_model.lm_head.weight.shape == base_model.lm_head.weight.shape
def test_adapter_filter(tmp_path):
fabric = Fabric(devices=1)
model = GPT.from_name("pythia-14m", n_layer=4)
save_path = tmp_path / "model.pth"
fabric.save(save_path, {"model": model}, filter={"model": adapter_filter})
saved = torch.load(save_path)["model"]
expected = {
"transformer.h.2.attn.adapter_wte.weight",
"transformer.h.2.attn.gating_factor",
"transformer.h.3.attn.adapter_wte.weight",
"transformer.h.3.attn.gating_factor",
}
assert set(saved) == expected
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
monkeypatch.setattr(module, "load_checkpoint", Mock())
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
out_dir=out_dir,
precision="32-true",
train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
)
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
"lit_model.pth.adapter",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.yaml",
}
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()
logs = stdout.getvalue()
assert logs.count("(step)") == 6
assert logs.count("val loss") == 4 # 3 validations + 1 final validation
assert logs.count("Final evaluation") == 1
assert "of trainable parameters: 168" in logs
def test_adapter_gpt_init_weights():
config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)
model = GPT(config)
param = model.transformer.h[0].attn.gating_factor
assert (param == 0).all()
torch.nn.init.constant_(param, 1.23)
assert (param != 0).any()
model.apply(model._init_weights)
assert (param == 0).all()
@_RunIf(dynamo=True)
@torch.inference_mode()
def test_adapter_compile():
model = GPT.from_name("pythia-14m", n_layer=3)
x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
explanation = torch._dynamo.explain(model)(x)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
model = GPT(model.config)
model.set_kv_cache(2)
input_pos = torch.arange(model.config.block_size)
explanation = torch._dynamo.explain(model)(x, input_pos)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
@_RunIf(min_cuda_gpus=1)
def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
if not _BITSANDBYTES_AVAILABLE:
pytest.skip("BNB not available")
from bitsandbytes.optim import PagedAdamW
model_config = dict(
block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True
)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
monkeypatch.setattr(module, "load_checkpoint", Mock())
train_mock = Mock()
train_mock.return_value = {
"raw_tokens": 1000,
"raw_tokens_plus_prompt_template": 1100,
"raw_tokens_plus_prompt_template_and_padding": 1200,
}
monkeypatch.setattr(module, "fit", train_mock)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
precision="16-true",
quantize="bnb.nf4-dq",
out_dir=tmp_path,
)
_, kwargs = train_mock.call_args
fabric = kwargs["fabric"]
model = kwargs["model"]
optimizer = kwargs["optimizer"]
assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
assert isinstance(optimizer, _FabricOptimizer)
assert isinstance(optimizer._optimizer, PagedAdamW)
dtype_to_name = {"torch.uint8": set(), "torch.float16": set()}
for name, layer in model.named_parameters():
name = name[len("_forward_module.") :]
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
"transformer.h.0.norm_1.weight",
"transformer.h.0.norm_1.bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.0.attn.qkv.bias",
"transformer.h.0.attn.proj.bias",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.0.norm_2.weight",
"transformer.h.0.norm_2.bias",
"transformer.h.0.mlp.fc.bias",
"transformer.h.0.mlp.proj.bias",
"transformer.h.1.norm_1.weight",
"transformer.h.1.norm_1.bias",
"transformer.h.1.attn.gating_factor",
"transformer.h.1.attn.qkv.bias",
"transformer.h.1.attn.proj.bias",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.norm_2.weight",
"transformer.h.1.norm_2.bias",
"transformer.h.1.mlp.fc.bias",
"transformer.h.1.mlp.proj.bias",
"transformer.ln_f.weight",
"transformer.ln_f.bias",
},
"torch.uint8": {
"lm_head.weight",
"transformer.h.0.attn.qkv.weight",
"transformer.h.0.attn.proj.weight",
"transformer.h.0.mlp.fc.weight",
"transformer.h.0.mlp.proj.weight",
"transformer.h.1.attn.qkv.weight",
"transformer.h.1.attn.proj.weight",
"transformer.h.1.mlp.fc.weight",
"transformer.h.1.mlp.proj.weight",
},
}
assert {p.name for p in tmp_path.rglob("*.pth.adapter")} == {"lit_model.pth.adapter"}
state_dict = torch.load(tmp_path / "final" / "lit_model.pth.adapter")
assert len(state_dict) == 1
dtype_to_name = {"torch.float16": set()}
for name, layer in state_dict["model"].items():
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.0.attn.gating_factor",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.attn.gating_factor",
}
}
logs = stdout.getvalue()
assert "of trainable parameters: 168" in logs
assert "of non-trainable parameters: 1,888" in logs
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 5
ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_2(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_2({}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_3(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_3({}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)
def test_load_legacy_state_dict():
"""Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers."""
config = Config(
n_embd=32,
n_head=4,
head_size=8,
n_query_groups=4,
bias=True,
)
attention_1 = CausalSelfAttention(config=config, block_idx=0)
# make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`
# and make them interleaved
state_dict = deepcopy(attention_1.state_dict())
state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config)
state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config)
attention_2 = CausalSelfAttention(config=config, block_idx=0)
attention_2.load_state_dict(state_dict)
================================================
FILE: tests/test_adapter_v2.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stdout
from copy import deepcopy
from io import StringIO
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
import yaml
from lightning import Fabric
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.wrappers import _FabricOptimizer
from torch._dynamo.backends import debugging
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import litgpt.config as config_module
import litgpt.finetune.adapter_v2 as module
from litgpt.adapter_v2 import GPT as AdapterV2GPT
from litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca
from litgpt.model import GPT as BaseGPT
from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
from litgpt.utils import _RunIf
def test_config_identical():
name = "pythia-14m"
with Fabric(accelerator="cpu").init_module(empty_init=True):
base_model = BaseGPT.from_name(name)
adapter_model = AdapterV2GPT.from_name(name)
assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias")
assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale")
assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias")
assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale")
def test_adapter_v2_filter(tmp_path):
fabric = Fabric(devices=1)
model = AdapterV2GPT.from_name("pythia-14m", n_layer=3)
save_path = tmp_path / "model.pth"
fabric.save(save_path, {"model": model}, filter={"model": adapter_filter})
saved = torch.load(save_path)["model"]
expected = {
"lm_head.adapter_bias",
"lm_head.adapter_scale",
"transformer.ln_f.bias",
"transformer.ln_f.weight",
"transformer.h.2.attn.adapter_wte.weight",
"transformer.h.2.attn.gating_factor",
}
for layer in range(3):
for param in (
"attn.qkv.adapter_bias",
"attn.qkv.adapter_scale",
"attn.proj.adapter_bias",
"attn.proj.adapter_scale",
"mlp.fc.adapter_bias",
"mlp.fc.adapter_scale",
"mlp.proj.adapter_bias",
"mlp.proj.adapter_scale",
"norm_1.bias",
"norm_1.weight",
"norm_2.bias",
"norm_2.weight",
):
expected.add(f"transformer.h.{layer}.{param}")
assert set(saved) == expected
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
monkeypatch.setattr(module, "load_checkpoint", Mock())
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
out_dir=out_dir,
precision="32-true",
train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
)
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
"lit_model.pth.adapter_v2",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.yaml",
}
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()
logs = stdout.getvalue()
assert logs.count("(step)") == 6
assert logs.count("val loss") == 4 # 3 validations + 1 final validation
assert logs.count("Final evaluation") == 1
assert "of trainable parameters: 552" in logs
def test_adapter_v2_gpt_init_weights():
config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)
model = AdapterV2GPT(config)
for param in (model.transformer.h[0].attn.gating_factor, model.lm_head.adapter_bias):
assert (param == 0).all()
torch.nn.init.constant_(param, 1.23)
assert (param != 0).any()
model.apply(model._init_weights)
assert (param == 0).all()
@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
def test_base_model_can_be_adapter_v2_loaded(name):
kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32}
base_model = BaseGPT.from_name(name, **kwargs)
base_model_state_dict = base_model.state_dict()
lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0)
keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert adapter_filter(k, None)
@_RunIf(dynamo=True)
@torch.inference_mode()
def test_adapter_v2_compile():
model = AdapterV2GPT.from_name("pythia-14m", n_layer=3)
x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
explanation = torch._dynamo.explain(model)(x)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
model = AdapterV2GPT(model.config)
model.set_kv_cache(2)
input_pos = torch.arange(model.config.block_size)
explanation = torch._dynamo.explain(model)(x, input_pos)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
@torch.inference_mode()
def test_against_hf_mixtral():
device = torch.device("cpu")
dtype = torch.float32
ours_config = Config.from_name(
"Mixtral-8x7B-Instruct-v0.1",
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
n_expert=4,
)
T = 5
theirs_config = MixtralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
num_local_experts=ours_config.n_expert,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = MixtralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = AdapterV2GPT(ours_config).to(device)
# strict=False because missing keys due to adapter weights not contained in state dict
ours_model.load_state_dict(state_dict, strict=False)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 5
ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = AdapterV2GPT(ours_config).to(device)
keys = ours_model.load_state_dict(state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert adapter_filter(k, None)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b"))
def test_against_original_gemma_2(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_2({}, state_dict, theirs_state_dict)
ours_model = AdapterV2GPT(ours_config).to(device)
keys = ours_model.load_state_dict(state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert adapter_filter(k, None)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(
# some macOS devices have numerical differences, hence the tol bump
ours_y,
theirs_y,
atol=1e-4,
rtol=1e-5,
)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
def test_against_original_gemma_3(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
rope_scaling={"factor": 8.0, "rope_type": "linear"},
rope_local_base_freq=ours_config.rope_local_base_freq,
)
theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_3({}, state_dict, theirs_state_dict)
ours_model = AdapterV2GPT(ours_config).to(device)
keys = ours_model.load_state_dict(state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert adapter_filter(k, None)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(
ours_y, theirs_y, rtol=3e-5, atol=3e-5
) # some macOS devices have numerical differences, hence the tol bump
@_RunIf(min_cuda_gpus=1)
def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
if not _BITSANDBYTES_AVAILABLE:
pytest.skip("BNB not available")
from bitsandbytes.optim import PagedAdamW
model_config = dict(
block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True
)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
monkeypatch.setattr(module, "load_checkpoint", Mock())
train_mock = Mock()
train_mock.return_value = {
"raw_tokens": 1000,
"raw_tokens_plus_prompt_template": 1100,
"raw_tokens_plus_prompt_template_and_padding": 1200,
}
monkeypatch.setattr(module, "fit", train_mock)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
precision="16-true",
quantize="bnb.nf4-dq",
out_dir=tmp_path,
)
_, kwargs = train_mock.call_args
fabric = kwargs["fabric"]
model = kwargs["model"]
optimizer = kwargs["optimizer"]
assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
assert isinstance(optimizer, _FabricOptimizer)
assert isinstance(optimizer._optimizer, PagedAdamW)
dtype_to_name = {"torch.uint8": set(), "torch.float16": set()}
for name, layer in model.named_parameters():
name = name[len("_forward_module.") :]
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.uint8": {
"transformer.h.0.mlp.fc.linear.weight",
"transformer.h.1.mlp.proj.linear.weight",
"transformer.h.1.attn.qkv.linear.weight",
"transformer.h.0.attn.proj.linear.weight",
"lm_head.linear.weight",
"transformer.h.1.attn.proj.linear.weight",
"transformer.h.0.mlp.proj.linear.weight",
"transformer.h.0.attn.qkv.linear.weight",
"transformer.h.1.mlp.fc.linear.weight",
},
"torch.float16": {
"transformer.h.1.attn.qkv.adapter_bias",
"transformer.h.1.mlp.proj.adapter_bias",
"transformer.h.0.attn.qkv.adapter_bias",
"transformer.h.0.norm_1.bias",
"transformer.h.0.attn.qkv.linear.bias",
"transformer.h.1.attn.adapter_wte.weight",
"transformer.ln_f.weight",
"transformer.h.0.mlp.fc.linear.bias",
"transformer.h.0.mlp.proj.linear.bias",
"transformer.h.1.mlp.fc.linear.bias",
"transformer.h.0.attn.proj.adapter_scale",
"transformer.h.0.attn.qkv.adapter_scale",
"transformer.h.1.norm_2.bias",
"transformer.h.1.attn.proj.adapter_scale",
"transformer.h.0.norm_2.bias",
"transformer.h.0.mlp.fc.adapter_scale",
"transformer.h.0.attn.proj.linear.bias",
"transformer.h.1.attn.proj.linear.bias",
"transformer.h.1.norm_1.bias",
"transformer.h.0.norm_1.weight",
"transformer.h.1.attn.proj.adapter_bias",
"transformer.h.0.mlp.proj.adapter_scale",
"transformer.h.0.mlp.proj.adapter_bias",
"transformer.h.1.mlp.fc.adapter_bias",
"transformer.h.1.mlp.proj.adapter_scale",
"transformer.h.1.attn.gating_factor",
"transformer.h.1.norm_1.weight",
"transformer.ln_f.bias",
"transformer.h.0.mlp.fc.adapter_bias",
"lm_head.adapter_scale",
"lm_head.adapter_bias",
"transformer.h.1.norm_2.weight",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.h.1.attn.qkv.adapter_scale",
"transformer.h.1.mlp.fc.adapter_scale",
"transformer.h.1.attn.qkv.linear.bias",
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
"transformer.h.0.norm_2.weight",
"transformer.h.1.mlp.proj.linear.bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.0.attn.proj.adapter_bias",
},
}
assert {p.name for p in tmp_path.rglob("*.pth.adapter_v2")} == {"lit_model.pth.adapter_v2"}
state_dict = torch.load(tmp_path / "final" / "lit_model.pth.adapter_v2")
assert len(state_dict) == 1
dtype_to_name = {"torch.float16": set()}
for name, layer in state_dict["model"].items():
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.h.1.attn.adapter_wte.weight",
"transformer.h.1.attn.proj.adapter_bias",
"transformer.h.1.mlp.fc.adapter_scale",
"lm_head.adapter_bias",
"transformer.h.0.mlp.proj.adapter_scale",
"transformer.ln_f.bias",
"lm_head.adapter_scale",
"transformer.h.1.norm_2.weight",
"transformer.h.0.attn.qkv.adapter_scale",
"transformer.h.0.mlp.proj.adapter_bias",
"transformer.h.0.attn.gating_factor",
"transformer.h.1.norm_1.bias",
"transformer.h.1.mlp.fc.adapter_bias",
"transformer.h.1.mlp.proj.adapter_scale",
"transformer.h.0.mlp.fc.adapter_scale",
"transformer.h.1.attn.qkv.adapter_bias",
"transformer.h.0.norm_2.weight",
"transformer.h.1.norm_2.bias",
"transformer.h.0.norm_1.weight",
"transformer.h.0.attn.proj.adapter_scale",
"transformer.h.1.mlp.proj.adapter_bias",
"transformer.h.0.attn.qkv.adapter_bias",
"transformer.h.0.attn.adapter_wte.weight",
"transformer.ln_f.weight",
"transformer.h.1.attn.gating_factor",
"transformer.h.0.mlp.fc.adapter_bias",
"transformer.h.1.attn.proj.adapter_scale",
"transformer.h.0.attn.proj.adapter_bias",
"transformer.h.0.norm_1.bias",
"transformer.h.0.norm_2.bias",
"transformer.h.1.norm_1.weight",
"transformer.h.1.attn.qkv.adapter_scale",
}
}
logs = stdout.getvalue()
assert "of trainable parameters: 552" in logs
assert "of non-trainable parameters: 1,808" in logs
def test_load_legacy_state_dict():
"""Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers."""
config = Config(
n_embd=32,
n_head=4,
head_size=8,
n_query_groups=4,
bias=True,
)
attention_1 = CausalSelfAttention(config=config, block_idx=0)
# make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`
# and make them interleaved
state_dict = deepcopy(attention_1.state_dict())
state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config)
state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config)
attention_2 = CausalSelfAttention(config=config, block_idx=0)
attention_2.load_state_dict(state_dict)
================================================
FILE: tests/test_api.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import re
import sys
from collections import OrderedDict
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from lightning.fabric.accelerators import CUDAAccelerator
from litgpt.api import LLM, benchmark_dict_to_markdown_table, calculate_number_of_devices
from litgpt.scripts.download import download_from_hub
from litgpt.utils import _RunIf
skip_in_ci_on_macos = pytest.mark.skipif(
sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true",
reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.",
)
if sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true":
USE_MPS = False
elif torch.backends.mps.is_available():
USE_MPS = True
else:
USE_MPS = False
@pytest.fixture
def mock_llm():
llm = MagicMock(spec=LLM)
llm.model = MagicMock()
llm.preprocessor = MagicMock()
llm.prompt_style = MagicMock()
llm.checkpoint_dir = MagicMock()
llm.fabric = MagicMock()
return llm
def test_load_model(mock_llm):
assert isinstance(mock_llm, LLM)
assert mock_llm.model is not None
assert mock_llm.preprocessor is not None
assert mock_llm.prompt_style is not None
assert mock_llm.checkpoint_dir is not None
assert mock_llm.fabric is not None
def test_generate(mock_llm):
prompt = "What do Llamas eat?"
mock_llm.generate.return_value = prompt + " Mock output"
output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5)
assert isinstance(output, str)
assert len(output) > len(prompt)
def test_stream_generate(mock_llm):
prompt = "What do Llamas eat?"
def iterator():
outputs = (prompt + " Mock output").split()
yield from outputs
mock_llm.generate.return_value = iterator()
output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5, stream=True)
result = "".join([out for out in output])
assert len(result) > len(prompt)
def test_generate_token_ids(mock_llm):
prompt = "What do Llamas eat?"
mock_output_ids = MagicMock(spec=torch.Tensor)
mock_output_ids.shape = [len(prompt) + 10]
mock_llm.generate.return_value = mock_output_ids
output_ids = mock_llm.generate(prompt, max_new_tokens=10, return_as_token_ids=True)
assert isinstance(output_ids, torch.Tensor)
assert output_ids.shape[0] > len(prompt)
def test_calculate_number_of_devices():
assert calculate_number_of_devices(1) == 1
assert calculate_number_of_devices([0, 1, 2]) == 3
assert calculate_number_of_devices(None) == 0
def test_llm_load_random_init(tmp_path):
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
torch.manual_seed(123)
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(model="pythia-160m", init="random", tokenizer_dir=Path(tmp_path / "EleutherAI/pythia-14m"))
input_text = "some text text"
output_text = llm.generate(input_text, max_new_tokens=15)
ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
assert ln <= 15
# The following below tests that generate works with different prompt lengths
# after the kv cache was set
input_text = "some text"
output_text = llm.generate(input_text, max_new_tokens=15)
ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
assert ln <= 15
input_text = "some text text text"
output_text = llm.generate(input_text, max_new_tokens=15)
ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
assert ln <= 15
def test_llm_load_hub_init(tmp_path):
torch.manual_seed(123)
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained")
text_1 = llm.generate("text", max_new_tokens=10, top_k=1)
assert len(text_1) > 0
text_2 = llm.generate("text", max_new_tokens=10, top_k=1, stream=True)
text_2 = "".join(list(text_2))
assert text_1 == text_2, (text_1, text_2)
def test_model_not_initialized(tmp_path):
llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained", distribute=None)
s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model."
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("text")
llm = LLM.load(model="EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random", distribute=None)
s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model."
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("text")
@_RunIf(min_cuda_gpus=2)
def test_more_than_1_device_for_sequential_gpu(tmp_path):
device_count = CUDAAccelerator.auto_device_count()
if device_count <= 2:
model_name = "EleutherAI/pythia-14m"
else:
model_name = "EleutherAI/pythia-160m"
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model=model_name,
)
with pytest.raises(
NotImplementedError,
match="Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'.",
):
llm.distribute(devices=2)
llm.distribute(devices=2, generate_strategy="sequential")
assert isinstance(llm.generate("What do llamas eat?"), str)
assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0"
last_layer_idx = len(llm.model.transformer.h) - 1
assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == "cuda:1"
# Also check with default (devices="auto") setting
llm.distribute(generate_strategy="sequential")
assert isinstance(llm.generate("What do llamas eat?"), str)
assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0"
assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == f"cuda:{device_count - 1}"
@_RunIf(min_cuda_gpus=2)
@pytest.mark.skipif(bool(os.getenv("SKIP_WITH_CI")), reason="Skip this test in CI due to ...")
def test_more_than_1_device_for_tensor_parallel_gpu(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(model="EleutherAI/pythia-14m")
# this crashes the CI, maybe because of process forking; works fine locally though
llm.distribute(devices=2, generate_strategy="tensor_parallel")
assert isinstance(llm.generate("What do llamas eat?"), str)
@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("strategy", ("sequential", "tensor_parallel"))
@pytest.mark.xfail(
NotADirectoryError, reason="This test is expected to fail due to a NotADirectoryError.", strict=False
)
def test_sequential_tp_incompatibility_with_random_weights(strategy, tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(model="EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random")
with pytest.raises(
NotImplementedError,
match=re.escape(
"The LLM was initialized with init='random' but .distribute() currently only supports pretrained weights."
),
):
llm.distribute(devices=1, generate_strategy=strategy)
@pytest.mark.parametrize("strategy", ("sequential", "tensor_parallel"))
def test_sequential_tp_cpu(strategy, tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
distribute=None,
)
with pytest.raises(
NotImplementedError, match=f"generate_strategy='{strategy}' is only supported for accelerator='cuda'|'gpu'."
):
llm.distribute(devices=1, accelerator="cpu", generate_strategy=strategy)
def test_initialization_for_trainer(tmp_path):
llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None)
s = "The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model."
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("hello world")
llm.trainer_setup()
llm.model.to(llm.preprocessor.device)
assert isinstance(llm.generate("hello world"), str)
@_RunIf(min_cuda_gpus=1)
def test_quantization_is_applied(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
llm.distribute(devices=1, quantize="bnb.nf4", precision="bf16-true")
strtype = str(type(llm.model.lm_head))
assert "NF4Linear" in strtype, strtype
@_RunIf(min_cuda_gpus=1)
def test_fixed_kv_cache(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
llm.distribute(devices=1, fixed_kv_cache_size=100)
# Request too many tokens
with pytest.raises(NotImplementedError, match="max_seq_length 512 needs to be >= 9223372036854775809"):
_ = llm.generate("hello world", max_new_tokens=2**63)
def test_invalid_accelerator(tmp_path):
llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None)
with pytest.raises(ValueError, match="Invalid accelerator"):
llm.distribute(accelerator="invalid")
def test_returned_benchmark_dir(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
text, bench_d = llm.benchmark(prompt="hello world")
assert isinstance(bench_d["Inference speed in tokens/sec"], list)
assert len(bench_d["Inference speed in tokens/sec"]) == 1
assert isinstance(bench_d["Inference speed in tokens/sec"][0], float)
text, bench_d = llm.benchmark(prompt="hello world", stream=True)
assert isinstance(bench_d["Inference speed in tokens/sec"], list)
assert len(bench_d["Inference speed in tokens/sec"]) == 1
assert isinstance(bench_d["Inference speed in tokens/sec"][0], float)
text, bench_d = llm.benchmark(num_iterations=10, prompt="hello world", stream=True)
assert isinstance(bench_d["Inference speed in tokens/sec"], list)
assert len(bench_d["Inference speed in tokens/sec"]) == 10
assert isinstance(bench_d["Inference speed in tokens/sec"][0], float)
def test_benchmark_dict_to_markdown_table_single_values():
bench_d = {
"Inference speed in tokens/sec": [17.617540650112936],
"Seconds to first token": [0.6533610639999097],
"Seconds total": [1.4758019020000575],
"Tokens generated": [26],
"Total GPU memory allocated in GB": [5.923729408],
}
expected_output = (
"| Metric | Mean | Std Dev |\n"
"|-------------------------------------|-----------------------------|-----------------------------|\n"
"| Inference speed in tokens/sec | 17.62 | nan |\n"
"| Seconds to first token | 0.65 | nan |\n"
"| Seconds total | 1.48 | nan |\n"
"| Tokens generated | 26.00 | nan |\n"
"| Total GPU memory allocated in GB | 5.92 | nan |\n"
)
assert benchmark_dict_to_markdown_table(bench_d) == expected_output
def test_benchmark_dict_to_markdown_table_multiple_values():
bench_d_list = {
"Inference speed in tokens/sec": [
17.034547562152305,
32.8974175404589,
33.04784205046782,
32.445697744648584,
33.204480197756396,
32.64187570945661,
33.21232058140845,
32.69377798373551,
32.92351459309756,
32.48909032591177,
],
"Seconds to first token": [
0.7403525039999295,
0.022901020000063,
0.02335712100011733,
0.022969672000272112,
0.022788318000039,
0.02365505999978268,
0.02320190000000366,
0.022791139999753796,
0.022871761999795126,
0.023060415999680117,
],
"Seconds total": [
1.5263099829999192,
0.7903355929997815,
0.7867382069998712,
0.8013389080001616,
0.7830268640000213,
0.7965228539997042,
0.7828420160003589,
0.7952583520000189,
0.7897091279996857,
0.8002686360000553,
],
"Tokens generated": [26, 26, 26, 26, 26, 26, 26, 26, 26, 26],
"Total GPU memory allocated in GB": [
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
5.923729408,
],
}
expected_output = (
"| Metric | Mean | Std Dev |\n"
"|-------------------------------------|-----------------------------|-----------------------------|\n"
"| Inference speed in tokens/sec | 31.26 | 5.01 |\n"
"| Seconds to first token | 0.09 | 0.23 |\n"
"| Seconds total | 0.87 | 0.23 |\n"
"| Tokens generated | 26.00 | 0.00 |\n"
"| Total GPU memory allocated in GB | 5.92 | 0.00 |\n"
)
assert benchmark_dict_to_markdown_table(bench_d_list) == expected_output
def test_state_dict(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
assert isinstance(llm.state_dict(), OrderedDict)
assert llm.state_dict()["lm_head.weight"].shape == torch.Size([50304, 128])
def test_save_method(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
target_dir = "saved_model"
llm.save(target_dir)
expected_files = [
"config.json",
"generation_config.json",
"lit_model.pth",
"model_config.yaml",
"prompt_style.yaml",
"tokenizer_config.json",
"tokenizer.json",
]
files_in_directory = os.listdir(target_dir)
for file_name in expected_files:
assert file_name in files_in_directory, f"{file_name} is missing from {target_dir}"
def test_forward_method(tmp_path):
with patch("torch.backends.mps.is_available", return_value=USE_MPS):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
inputs = torch.ones(6, 128, dtype=torch.int64).to(next(llm.model.parameters()).device)
assert llm(inputs).shape == torch.Size([6, 128, 50304])
logits, loss = llm(inputs, target_ids=inputs)
assert logits.shape == torch.Size([6, 128, 50304])
assert isinstance(loss.item(), float)
@skip_in_ci_on_macos # The macOS CI machine segfaults here (it works fine locally though)
def test_precision_selection(tmp_path):
llm = LLM.load(model="EleutherAI/pythia-14m", init="pretrained")
llm.distribute(precision="16-true")
assert llm.model._forward_module.lm_head.weight.dtype == torch.float16, (
f"Expected float16, but got {llm.model._forward_module.lm_head.weight.dtype}"
)
================================================
FILE: tests/test_args.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import pytest
from litgpt.args import TrainArgs
def test_compute_warmup_iters():
# warmup disabled
train = TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=0)
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 0
# lr_warmup_steps and lr_warmup_fraction both are not allowed
with pytest.raises(ValueError, match="Can't provide both `--train.lr_warmup_fraction`"):
TrainArgs(lr_warmup_steps=1, lr_warmup_fraction=0.2)
# lr_warmup_fraction invalid range
with pytest.raises(ValueError, match=" must be between 0 and 1"):
TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=1.1)
# lr_warmup_steps
train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=0)
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 100
# lr_warmup_steps multiplied by accumulation factor
train.global_batch_size = 4
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 400
assert train.warmup_iters(devices=2, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 200
# lr_warmup_steps truncated by max iters
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=120, train_dataloader=range(10)) == 120
# lr_warmup_fraction
train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=0, lr_warmup_fraction=0.3)
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(100)) == 30
# lr_warmup_fraction truncated by max iters
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=20, train_dataloader=range(100)) == 20
# lr_warmup_fraction rounds up
assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(5)) == 2
================================================
FILE: tests/test_batch.py
================================================
import warnings
from pathlib import Path
import lightning as L
import pytest
import torch
import litgpt
from litgpt.api import GPT, LLM
from litgpt.generate.base import (
batched_generate_fn,
batched_next_token,
generate_fn,
next_token,
)
from litgpt.scripts.download import download_from_hub
from litgpt.utils import _RunIf
warnings.filterwarnings("ignore")
def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]:
L.seed_everything(42)
model_name = "microsoft/phi-2"
download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path)
llm: LLM = LLM.load(
model_name,
tokenizer_dir=Path(tmp_path / model_name),
init="random",
)
model: GPT = llm.model
model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length, device=device)
return llm, model
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires a GPU.")
def test_batched_equivalence(tmp_path):
model_name = "microsoft/phi-2"
download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path)
device = "cuda:0"
batch_size = 3
sample_kwargs = {"top_k": 1}
llm: LLM = LLM.load(
model_name,
tokenizer_dir=Path(tmp_path / model_name),
init="random",
)
model: GPT = llm.model
model.set_kv_cache(batch_size=1, max_seq_length=50, device=device)
input_pos_1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device)
input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device)
x = torch.tensor(
[43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],
device=device,
dtype=torch.int64,
)
batch_x1 = torch.stack([x] * batch_size, dim=0)
# Single token generation baseline
tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs)
tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs)
assert tok_1.ndim == 1
assert tok_2.ndim == 1
assert tok_1.size(0) == 1
assert tok_2.size(0) == 1
# Switch to batched generation
model.clear_kv_cache()
model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0")
toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs)
toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs)
assert toks_1.ndim == 2
assert toks_2.ndim == 2
assert toks_1.size(0) == batch_size
assert toks_2.size(0) == batch_size
# Assert that single and batched next token generation are equivalent
assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}"
assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}"
@_RunIf(min_cuda_gpus=1)
def test_simple_batch():
old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
config = litgpt.Config.from_name("microsoft/phi-2", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256)
with torch.device("cuda"):
m = litgpt.GPT(config).requires_grad_(False).eval()
x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]])
input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]])
x1 = torch.tensor([[1], [2]])
input_pos1 = torch.tensor([[4], [3]])
with torch.device("cuda"):
m.set_kv_cache(2)
outs0 = m(x0, input_pos0)
outs1 = m(x1, input_pos1)
with torch.device("cuda"):
m.set_kv_cache(1)
outs0_ref0 = m(x0[:1], input_pos0[0])
outs1_ref0 = m(x1[:1], input_pos1[0])
with torch.device("cuda"):
m.set_kv_cache(1)
outs0_ref1 = m(x0[1:], input_pos0[1])
outs1_ref1 = m(x1[1:], input_pos1[1])
outs0_ref = torch.cat([outs0_ref0, outs0_ref1])
outs1_ref = torch.cat([outs1_ref0, outs1_ref1])
print(outs0_ref - outs0)
print(outs0.shape)
torch.testing.assert_close(outs0, outs0_ref)
torch.testing.assert_close(outs1, outs1_ref)
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
@_RunIf(min_cuda_gpus=1)
def test_batch_generate(tmp_path):
torch.use_deterministic_algorithms(True)
device = "cuda:0"
batch_size = 3
sample_kwargs = {"top_k": 1}
llm, model = create_llm(tmp_path, batch_size, 50, device)
batch_x = torch.tensor(
[
[43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],
[25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596],
[1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596],
],
device=device,
dtype=torch.int64,
)
# Generate tokens
tokens = []
for l in batched_generate_fn(
model,
prompts=batch_x,
max_returned_tokens=50,
sample_args=sample_kwargs,
include_prompt=True,
include_eos=False,
):
tokens.append([t.item() if t is not None else None for t in l])
def find_unique_stop(triplets):
# Initialize a dictionary to count all number occurrences
number_count = {}
# Count occurrences of each number across all positions
for triplet in triplets:
for num in triplet:
number_count[num] = number_count.get(num, 0) + 1
# Initialize lists to store unique numbers for each position
unique_first = []
unique_second = []
unique_third = []
# Check each triplet
for a, b, c in triplets:
if number_count[a] == 1:
unique_first.append(a)
if number_count[b] == 1:
unique_second.append(b)
if number_count[c] == 1:
unique_third.append(c)
import random # Seeded earlier
random.shuffle(unique_first)
random.shuffle(unique_second)
random.shuffle(unique_third)
return [unique_first[0], unique_second[0], unique_third[0]]
# Now that we know the randomly generated tokens, sample some tokens to stop each stream at.
stops = find_unique_stop(tokens[batch_x.size(1) :])
first_stream = [t[0] for t in tokens if t[0] is not None]
second_stream = [t[1] for t in tokens if t[1] is not None]
third_stream = [t[2] for t in tokens if t[2] is not None]
# Let's slice the streams at the stop tokens.
stop_idxes = [
first_stream.index(stops[0]),
second_stream.index(stops[1]),
third_stream.index(stops[2]),
]
# While we're at it, grab the last token that would be generated before stopping.
last_tokens = [
first_stream[stop_idxes[0] - 1],
second_stream[stop_idxes[1] - 1],
third_stream[stop_idxes[2] - 1],
]
for t in tokens:
print(t)
# Now we generate again, stopping early at the stop tokens.
tokens = []
for l in batched_generate_fn(
model,
prompts=batch_x,
max_returned_tokens=50,
stop_tokens=[(s,) for s in stops],
sample_args=sample_kwargs,
include_prompt=True,
include_eos=False,
):
tokens.append([t.item() if t is not None else None for t in l])
# Finally, assert that the streams are correct.
first_stream = [t[0] for t in tokens if t[0] is not None]
print(first_stream)
print(len(first_stream), stop_idxes[0])
assert len(first_stream) == stop_idxes[0]
assert first_stream[-1] == last_tokens[0]
second_stream = [t[1] for t in tokens if t[1] is not None]
print(second_stream)
print(len(second_stream), stop_idxes[1])
assert len(second_stream) == stop_idxes[1]
assert second_stream[-1] == last_tokens[1]
third_stream = [t[2] for t in tokens if t[2] is not None]
print(third_stream)
print(len(third_stream), stop_idxes[2])
assert len(third_stream) == stop_idxes[2]
assert third_stream[-1] == last_tokens[2]
torch.use_deterministic_algorithms(False)
# for t in llm.tokenizer.decode_stream([torch.tensor(i) for i in first_stream]):
# print(t, end="", flush=True)
# print()
@_RunIf(min_cuda_gpus=1)
def test_batch_generate_equivalence(tmp_path):
torch.use_deterministic_algorithms(True)
device = "cuda:0"
batch_size = 3
sample_kwargs = {"top_k": 1}
llm, model = create_llm(tmp_path, batch_size, 50, device)
batch_x = torch.tensor(
[
[43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],
[25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596],
[1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596],
],
device=device,
dtype=torch.int64,
)
# The other test tests the stop_tokens functionality much more exhaustively, we'll just generate and compare 50 tokens here.
batch_tokens = []
for l in batched_generate_fn(
model,
prompts=batch_x,
max_returned_tokens=50,
sample_args=sample_kwargs,
include_prompt=False,
include_eos=False,
):
batch_tokens.append([t.item() if t is not None else None for t in l])
first_stream = [t[0] for t in batch_tokens if t[0] is not None]
batch_size = 1
llm, model = create_llm(tmp_path, batch_size, 50, device)
tokens = []
for t in generate_fn(
model,
prompt=batch_x[0],
max_returned_tokens=50,
include_prompt=False,
include_eos=False,
**sample_kwargs,
):
if t.size(0) == 1:
tokens.append(t.item())
else:
tokens.extend(t.tolist())
torch.use_deterministic_algorithms(False)
# TODO: (apaz-cli) This consistency test doesn't actually work at the moment. It's inconsistent.
# The output is really close... Something is going on here. For the moment, maybe this is close enough?
# Enough at least that we can start prototyping.
print(first_stream)
print(tokens)
# assert first_stream == tokens
================================================
FILE: tests/test_chat.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import re
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from itertools import repeat
from pathlib import Path
from typing import Iterable, Iterator
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
import yaml
import litgpt.chat.base as chat
import litgpt.generate.base as generate
from litgpt import Config, Tokenizer
from litgpt.utils import auto_download_checkpoint, save_config
skip_in_ci_on_macos = pytest.mark.skipif(
sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true",
reason="Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.",
)
@pytest.mark.parametrize(
("generated", "stop_tokens", "expected"),
[
(repeat(1), (), [1] * 8),
([1, 2, 3, 0], ([0],), [1, 2, 3]),
([1, 2, 3, 0], ([9], [2, 4], [1, 2, 3, 0]), []),
([1, 2, 3, 0, 0], ([0, 0, 0], [0, 0]), [1, 2, 3]),
([3, 1, 2], ([1, 2], [3]), []),
([1, 2, 3, 0, 3, 2, 1, 0], ([4, 3, 2, 1], [2, 4]), [1, 2, 3, 0, 3, 2, 1, 0]),
],
)
def test_generate(monkeypatch, generated, stop_tokens, expected):
import lightning as L
L.seed_everything(1234)
input_idx = torch.tensor([5, 3])
max_returned_tokens = len(input_idx) + 8
model = MagicMock()
model.config.block_size = 100
model.max_seq_length = 100
it = iter(generated)
def multinomial(*_, **__):
out = next(it)
return torch.tensor([out])
monkeypatch.setattr(generate, "multinomial_num_samples_1", multinomial)
actual = chat.generate(model, input_idx, max_returned_tokens, stop_tokens=stop_tokens)
actual = list(actual)
assert len(actual) == len(expected), (actual, expected)
if not actual:
assert actual == expected, (actual, expected)
else:
for t in actual:
assert t.dtype == torch.long, t.dtype
actual_list = torch.cat(actual).tolist()
assert actual_list == expected, (actual_list, expected)
def test_decode():
checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")
tokenizer = Tokenizer(checkpoint_dir)
text = (
"Hello World! This a bunch of text. Lorem ipsum dolor sit amet, "
"consectetur adipiscing elit, sed do eiusmod tempor incididunt "
"ut labore et dolore magna aliqua."
)
encoded: torch.Tensor = tokenizer.encode(text)
encoded_stream: Iterable[torch.Tensor] = torch.tensor_split(encoded, encoded.shape[0], dim=0)
decoded_stream: Iterator[str] = tokenizer.decode_stream(encoded_stream)
decoded: str = "".join(decoded_stream)
# Note that encoded and decoded text will not always be character for character identical.abs
# Indeed, sometimes it is not. But that tends to be because of special cases, and this is not
# one of those.
assert text == decoded, (text, decoded)
@skip_in_ci_on_macos
@patch("litgpt.chat.base.input")
@pytest.mark.parametrize("stop_iteration", [KeyboardInterrupt, ""])
def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, tensor_like):
# these values will be iteratively provided for each `input()` call
mocked_input.side_effect = ["Hello", stop_iteration]
config_path = fake_checkpoint_dir / "model_config.yaml"
config = {
"name": "Llama 3",
"block_size": 128,
"vocab_size": 50,
"n_layer": 2,
"n_head": 4,
"n_embd": 8,
"rotary_percentage": 1,
}
config_path.write_text(yaml.dump(config))
load_mock = Mock()
load_mock.return_value = load_mock
monkeypatch.setattr(chat, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.backend = "sentencepiece"
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode_stream.return_value = "foo bar baz"
monkeypatch.setattr(chat, "Tokenizer", tokenizer_mock)
generate_mock = MagicMock()
generate_mock.__iter__.return_value = [torch.tensor([3, 2, 1])]
monkeypatch.setattr(chat, "generate", generate_mock)
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
chat.main(temperature=2.0, max_new_tokens=10, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)
# decoding is done per each generated item
assert len(tokenizer_mock.return_value.decode_stream.mock_calls) == 1
assert tokenizer_mock.return_value.decode_stream.call_args[0][0] is generate_mock.return_value # Now a Mock
# Assert that the generated result is printed to stdout
assert re.match(r".*Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL), out.getvalue()
def test_cli():
args = ["litgpt", "chat", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Chat with a model" in output
@skip_in_ci_on_macos
@patch("litgpt.chat.base.input")
@patch("litgpt.chat.base.merge_lora")
def test_merge_lora_if_needed(mocked_merge_lora, mocked_input, fake_checkpoint_dir, monkeypatch, tensor_like):
# these values will be iteratively provided for each `input()` call
mocked_input.side_effect = [""]
# pretend there is an unmerged LORA checkpoint
os.rename(fake_checkpoint_dir / "lit_model.pth", fake_checkpoint_dir / "lit_model.pth.lora")
mocked_merge_lora.side_effect = lambda _: Path(fake_checkpoint_dir / "lit_model.pth").touch()
config = Config.from_name("pythia-14m")
save_config(config, fake_checkpoint_dir)
monkeypatch.setattr(chat, "load_checkpoint", Mock())
monkeypatch.setattr(chat, "Tokenizer", Mock())
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
chat.main(checkpoint_dir=fake_checkpoint_dir)
assert re.match(r".*Merging LoRA weights with the base model\..*", out.getvalue(), re.DOTALL)
mocked_merge_lora.assert_called_once()
@skip_in_ci_on_macos
def test_litgpt_chat_endtoend():
from litgpt.chat.base import main
checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")
# Patch input() and redirect stdout. Raise to exit the repl.
simulated_input = Mock(side_effect=["input", KeyboardInterrupt])
captured_output = StringIO()
with patch("builtins.input", simulated_input):
with redirect_stdout(captured_output):
try:
main(checkpoint_dir=checkpoint_dir, max_new_tokens=256, top_k=1)
except KeyboardInterrupt:
pass
# pythia-14m is not instruct-tuned, so it does not give an "answer" per se, but a continuation.
output = captured_output.getvalue()
assert ">> Reply: " in output, f"Expected reply not found. Got:\n{output}"
# Verify the model actually generated some text after the reply prompt
reply_start = output.index(">> Reply: ") + len(">> Reply: ")
assert len(output[reply_start:].strip()) > 0, f"Expected non-empty reply. Got:\n{output}"
assert simulated_input.call_count == 2
@skip_in_ci_on_macos
def test_litgpt_generate_endtoend():
from litgpt.generate.base import main
checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m")
captured_output = StringIO()
with redirect_stdout(captured_output):
try:
main(checkpoint_dir=checkpoint_dir, prompt="Hello World", max_new_tokens=256, top_k=1)
except KeyboardInterrupt:
pass
# pythia-14m is not instruct-tuned, so it does not give an "answer" per se, but a continuation.
assert "Hello World!" in captured_output.getvalue(), (
f"Expected output not found. Got:\n{captured_output.getvalue()}"
)
================================================
FILE: tests/test_ci.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE
from litgpt.utils import _RunIf
@_RunIf(min_cuda_gpus=1)
def test_gpu_ci_installs_bitsandbytes():
assert _BITSANDBYTES_AVAILABLE, str(_BITSANDBYTES_AVAILABLE)
================================================
FILE: tests/test_cli.py
================================================
import sys
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
import pytest
from packaging.version import Version
from litgpt.__main__ import main
def test_cli():
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "-h"]):
main()
out = out.getvalue()
assert "usage: litgpt" in out
assert (
"{download,chat,finetune,finetune_lora,finetune_full,finetune_adapter,finetune_adapter_v2,"
"pretrain,generate,generate_full,generate_adapter,generate_adapter_v2,generate_sequentially,"
"generate_speculatively,generate_tp,convert_to_litgpt,convert_from_litgpt,convert_pretrained_checkpoint,"
"merge_lora,evaluate,serve}" in out
)
assert (
"""Available subcommands:
download Download weights or tokenizer data from the Hugging
Face Hub.
chat Chat with a model."""
in out
)
assert """evaluate Evaluate a model with the LM Evaluation Harness.""" in out
assert """serve Serve a LitGPT model using LitServe.""" in out
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "finetune_lora", "-h"]):
main()
out = out.getvalue()
assert (
"""--lora_alpha LORA_ALPHA
The LoRA alpha. (type: int, default: 16)"""
in out
)
if Version(f"{sys.version_info.major}.{sys.version_info.minor}") < Version("3.9"):
# python 3.8 prints `Union[int, null]` instead of `Optional[int]`
return
out = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out), mock.patch("sys.argv", ["litgpt", "pretrain", "-h"]):
main()
out = out.getvalue()
print(out)
assert (
"""--train.max_tokens MAX_TOKENS
Total number of tokens to train on (type:
Optional[int], default: 3000000000000)"""
in out
)
def test_pretrain_allows_max_steps():
# Ensure --train.max_steps is accepted by the CLI for pretrain
# and only emits a warning instead of raising a validation error.
args = [
"litgpt",
"pretrain",
"pythia-14m",
"--train.max_steps=1",
"--out_dir=out/test-cli",
]
with pytest.warns(UserWarning, match="max_steps"):
try:
with mock.patch("sys.argv", args):
main()
except Exception:
pass
def test_rewrite_finetune_command():
out1 = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch("sys.argv", ["litgpt", "fineune", "-h"]):
main()
out2 = StringIO()
with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch("sys.argv", ["litgpt", "fineune_lora", "-h"]):
main()
assert out1.getvalue() == out2.getvalue()
================================================
FILE: tests/test_config.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import pytest
import yaml
import litgpt.config as config_module
from litgpt import Config
from litgpt.config import find_multiple
def test_config():
config = Config()
assert config.name == ""
assert config.block_size == 4096
config = Config(block_size=2048)
assert config.block_size == 2048
config = Config.from_name("pythia-14m")
assert config.block_size == 512
config = Config.from_name("pythia-14m", block_size=4096)
assert config.block_size == 4096
config = Config(hf_config={"name": "pythia-14m"})
assert config.name == "pythia-14m"
def test_from_hf_name():
# by short-hand name
config0 = Config.from_name("tiny-llama-1.1b")
# or by huggingface hub repo name
config1 = Config.from_name("TinyLlama-1.1B-intermediate-step-1431k-3T")
assert config0 is not None
assert config1 is not None
assert config0 == config1
def test_nonexisting_name():
with pytest.raises(ValueError, match="'invalid-model-name' is not a supported config name"):
Config.from_name("invalid-model-name")
@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
def test_short_and_hf_names_are_equal_unless_on_purpose(config):
# by short-hand name
config0 = Config.from_name(config["name"])
# or by huggingface hub repo name
config1 = Config.from_name(config["hf_config"]["name"])
assert config0.name == config1.name
def test_from_hf_name_with_org_string():
# Test case 1: valid input
config0 = Config.from_name("tiny-llama-1.1b")
config1 = Config.from_name("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
assert config0 is not None
assert config1 is not None
assert config0 == config1
# Test case 2: invalid input - org not found
with pytest.raises(
ValueError, match="'UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T' is not a supported config name"
):
Config.from_name("UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T")
# Test case 3: invalid input - name not found
with pytest.raises(ValueError, match="'TinyLlama/TinyLlama-XYZ' is not a supported config name"):
Config.from_name("TinyLlama/TinyLlama-XYZ")
def test_from_checkpoint(tmp_path):
# 1. Neither `lit_config.py` nor matching config exists.
with pytest.raises(FileNotFoundError, match="neither 'model_config.yaml' nor matching config exists"):
Config.from_checkpoint(tmp_path / "non_existing_checkpoint")
# 2. If `lit_config.py` doesn't exists, but there is a matching config in `litgpt/config.py`.
config = Config.from_checkpoint(tmp_path / "pythia-14m")
assert config.name == "pythia-14m"
assert config.block_size == 512
assert config.n_layer == 6
# 3. If only `lit_config.py` exists.
config_data = {"name": "pythia-14m", "block_size": 24, "n_layer": 2}
with open(tmp_path / "model_config.yaml", "w", encoding="utf-8") as file:
yaml.dump(config_data, file)
config = Config.from_checkpoint(tmp_path)
assert config.name == "pythia-14m"
assert config.block_size == 24
assert config.n_layer == 2
# 4. Both `lit_config.py` and a matching config exist, but `lit_config.py` supersedes matching config
(tmp_path / "pythia-14m").mkdir()
with open(tmp_path / "pythia-14m/model_config.yaml", "w", encoding="utf-8") as file:
yaml.dump(config_data, file)
config = Config.from_checkpoint(tmp_path / "pythia-14m")
assert config.name == "pythia-14m"
assert config.block_size == 24
assert config.n_layer == 2
@pytest.mark.parametrize("head_size", [None, 128])
def test_head_size(head_size):
config = Config(head_size)
assert config.head_size == head_size or config.n_embd // config.n_head
def test_find_multiple():
assert find_multiple(17, 5) == 20
assert find_multiple(30, 7) == 35
assert find_multiple(10, 2) == 10
assert find_multiple(5, 10) == 10
assert find_multiple(50254, 128) == 50304
assert find_multiple(50254, 256) == 50432
assert find_multiple(50254, 512) == 50688
================================================
FILE: tests/test_config_hub.py
================================================
import importlib
import importlib.util
from pathlib import Path
from unittest import mock
from unittest.mock import Mock
import pytest
from lightning.fabric.plugins import Precision
from litgpt import Config
from litgpt.utils import CLI
fixed_pairs = [
("litgpt/pretrain.py", "pretrain/debug.yaml"),
("litgpt/pretrain.py", "pretrain/tinyllama.yaml"),
("litgpt/pretrain.py", "pretrain/tinystories.yaml"),
(
"litgpt/pretrain.py",
"https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/tinystories.yaml",
),
]
config_hub_path = Path(__file__).parent.parent / "config_hub" / "finetune"
model_pairs = []
for model_dir in config_hub_path.iterdir():
if model_dir.is_dir():
model_name = model_dir.name
for yaml_file in model_dir.glob("*.yaml"):
config_name = yaml_file.stem
python_file = "litgpt/finetune/full.py" if config_name == "full" else "litgpt/finetune/lora.py"
relative_yaml_path = yaml_file.relative_to(config_hub_path.parent)
model_pairs.append((python_file, str(relative_yaml_path)))
all_pairs = fixed_pairs + model_pairs
@pytest.mark.parametrize(("script_file", "config_file"), all_pairs)
def test_config_help(script_file, config_file, monkeypatch):
"""Test that configs validate against the signature in the scripts."""
script_file = Path(__file__).parent.parent / script_file
assert script_file.is_file()
if "http" not in str(config_file):
config_file = Path(__file__).parent.parent / "config_hub" / config_file
assert config_file.is_file()
spec = importlib.util.spec_from_file_location(str(script_file.parent.name), script_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
monkeypatch.setattr(module, "main", Mock())
monkeypatch.setattr(module, "Tokenizer", Mock())
monkeypatch.setattr(module, "BitsandbytesPrecision", Mock(return_value=Precision()), raising=False)
monkeypatch.setattr(module, "Config", Mock(return_value=Config.from_name("pythia-14m")))
monkeypatch.setattr(module, "check_valid_checkpoint_dir", Mock(), raising=False)
try:
with mock.patch("sys.argv", [script_file.name, "--config", str(config_file), "--devices", "1"]):
CLI(module.setup)
module.main.assert_called_once()
except FileNotFoundError:
pass
# FileNotFound occurs here because we have not downloaded the model weights referenced in the config files
# which is ok because here we just want to validate the config file itself.
================================================
FILE: tests/test_deepseek_moe.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM
from litgpt import Config
from litgpt.model import GPT, LLaMAMLP
@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):
"""Test MOE litgpt vs hf"""
config_litgpt = Config(
padded_vocab_size=10000,
n_layer=2,
vocab_size=10000,
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
norm_eps=1e-6,
bias=False,
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
n_expert=16,
n_shared_expert=1,
n_expert_per_token=2,
n_expert_groups=4,
n_topk_groups=2,
n_topk_scores_per_group=2, # Note: Deepseek hardcodes this to `2`
first_k_dense_replace=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
moe_intermediate_size=20,
mlp_class_name="LLaMAMoE",
)
config_hf = DeepseekV3Config(
padded_vocab_size=10000,
num_hidden_layers=2,
vocab_size=10000,
hidden_size=64,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=16,
qk_rope_head_dim=8,
qk_nope_head_dim=8,
v_head_dim=16,
rope_interleave=False,
first_k_dense_replace=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
n_routed_experts=config_litgpt.n_expert,
n_shared_experts=config_litgpt.n_shared_expert,
num_experts_per_tok=config_litgpt.n_expert_per_token,
n_group=config_litgpt.n_expert_groups,
topk_group=config_litgpt.n_topk_groups,
moe_intermediate_size=config_litgpt.moe_intermediate_size,
)
model_litgpt = GPT(config_litgpt).to(device)
model_litgpt.apply(model_litgpt._init_weights)
mlp_litgpt = model_litgpt.transformer.h[0].mlp
assert isinstance(mlp_litgpt, LLaMAMLP) # Test first_k_dense_replace (k=1)
moe_litgpt = model_litgpt.transformer.h[1].mlp
model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
moe_hf = model_hf.model.layers[1].mlp
moe_litgpt.eval()
moe_hf.eval()
sync_weights(moe_litgpt, moe_hf)
hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)
output_litgpt = moe_litgpt(hidden_states)
output_hf = moe_hf(hidden_states)
assert torch.allclose(output_litgpt, output_hf, atol=1e-5)
def sync_weights(litgpt_model, hf_model):
print("Synchronizing MoE weights...")
with torch.no_grad():
if hasattr(litgpt_model, "gate"):
if hasattr(litgpt_model.gate, "weight"):
hf_model.gate.weight.copy_(litgpt_model.gate.weight)
if hasattr(litgpt_model.gate, "e_score_correction_bias"):
hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)
for i, (litgpt_expert, hf_expert) in enumerate(zip(litgpt_model.experts, hf_model.experts)):
hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)
hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)
hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)
if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"):
hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)
hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)
hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)
print("MoE weight synchronization complete.")
================================================
FILE: tests/test_distributed.py
================================================
import pytest
import torch
from lightning import Fabric
from litgpt.utils import _RunIf
@_RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("strategy", ["ddp", "fsdp"])
def test_no_backward_sync(strategy):
fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
fabric.launch()
# account for sharding in the case of FSDP
out_features = 1 if "ddp" in strategy else fabric.world_size
model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device)
x = torch.randn(1, 1, device=fabric.device)
model = fabric.setup(model)
# 6 iters, 3 grad accumulation iters
for i, enabled in enumerate((True, True, False, True, True, False), 1):
x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32)
with fabric.no_backward_sync(model, enabled):
y = model(x)
fabric.backward(y.sum())
if not enabled:
# Math for the first 3 iters
#
# DistributedDataParallel
# (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9
# ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^
# rank0 rank1 allreduce
#
# thunder.distributed.ddp
# ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9
# ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^
# rank0 rank1 allreduce1 rank0 rank1 allreduce2
assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape
assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)
model.weight.grad = None
================================================
FILE: tests/test_evaluate.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import subprocess
from contextlib import redirect_stdout
from dataclasses import asdict
from io import StringIO
from unittest import mock
import pytest
import torch
import yaml
import litgpt.eval.evaluate as module
from litgpt import GPT, Config
from litgpt.scripts.download import download_from_hub
@pytest.mark.flaky(reruns=3)
def test_evaluate_script(tmp_path):
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
checkpoint_dir = tmp_path / "EleutherAI" / "pythia-14m"
ours_model = GPT(ours_config)
torch.save(ours_model.state_dict(), checkpoint_dir / "lit_model.pth")
with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]):
with pytest.raises(ValueError) as excinfo:
module.convert_and_evaluate(
checkpoint_dir,
out_dir=tmp_path / "out_dir",
device=None,
dtype=torch.float32,
limit=5,
tasks="logiqa",
batch_size=0, # Test for non-positive integer
)
assert "batch_size must be a positive integer, 'auto', or in the format 'auto:N'." in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
module.convert_and_evaluate(
checkpoint_dir,
out_dir=tmp_path / "out_dir",
device=None,
dtype=torch.float32,
limit=5,
tasks="logiqa",
batch_size="invalid", # Test for invalid string
)
assert "batch_size must be a positive integer, 'auto', or in the format 'auto:N'." in str(excinfo.value)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["eval/evaluate.py"]):
module.convert_and_evaluate(
checkpoint_dir,
out_dir=tmp_path / "out_dir",
device=None,
dtype=torch.float32,
limit=5,
tasks="logiqa",
batch_size=1, # Valid case
)
stdout = stdout.getvalue()
assert (tmp_path / "out_dir" / "results.json").is_file()
assert "logiqa" in stdout
assert "Metric" in stdout
assert "Loading checkpoint shards" not in stdout
def test_cli():
args = ["litgpt", "evaluate", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Evaluate a model with the LM Evaluation Harness" in output
================================================
FILE: tests/test_full.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import Mock
import torch
import yaml
import litgpt.finetune.full as module
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
monkeypatch.setattr(module, "load_checkpoint", Mock())
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
out_dir = tmp_path / "out"
setup_args = (fake_checkpoint_dir,)
setup_kwargs = dict(
data=Alpaca(download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0),
out_dir=out_dir,
precision="32-true",
train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]):
module.setup(*setup_args, **setup_kwargs)
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
assert set(os.listdir(out_dir / checkpoint_dir)) == {
"lit_model.pth",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.yaml",
}
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()
logs = stdout.getvalue()
assert logs.count("(step)") == 6
assert logs.count("val loss") == 4 # 3 validations + 1 final validation
assert logs.count("Final evaluation") == 1
assert "of trainable parameters: 1,888" in logs
# Resume training and do 2 steps more
setup_kwargs["train"].max_steps = 8
setup_kwargs["resume"] = True
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]):
module.setup(*setup_args, **setup_kwargs)
logs = stdout.getvalue()
assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs
assert logs.count("(step)") == 2
assert out_dir / "step-000008" in set(out_dir.iterdir())
================================================
FILE: tests/test_generate_speculatively.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import re
import subprocess
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from unittest.mock import ANY, Mock, call
import pytest
import torch
import yaml
from torch import nn
import litgpt.generate.speculative_decoding as generate
from litgpt import GPT, Config
from litgpt.utils import _RunIf
def test_speculative_decoding_target_never_accepts_draft_tokens():
class DraftModel(nn.Module):
def forward(self, **kwargs):
return torch.tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C)
class TargetModel(nn.Module):
def forward(self, idx, **kwargs):
_, T = idx.shape
return torch.tensor([[0, 0, 0, 0, 0, 6, 7, 8, 9, 10]] * T, dtype=torch.float)[None, ...] # (B, T, C)
draft_model = DraftModel()
target_model = TargetModel()
token = torch.tensor([-1])
input_pos = torch.tensor([0])
sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent
output = generate.speculative_decoding(
draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs
)
# target model never accepts draft model's output, thus the output of the `speculative_decoding`
# is a single token sampled from the target model
assert len(output) == 1
assert output > 5
def test_speculative_decoding_target_always_accepts_draft_tokens():
class DraftModel(nn.Module):
def forward(self, **kwargs):
return torch.tensor([0, 0, 3, 4, 5, 6, 7, 8, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C)
class TargetModel(nn.Module):
def forward(self, idx, **kwargs):
_, T = idx.shape
return torch.tensor([[0, 0, 3, 4, 5, 6, 7, 8, 0, 0]] * T, dtype=torch.float)[None, ...] # (B, T, C)
draft_model = DraftModel()
target_model = TargetModel()
token = torch.tensor([-1])
input_pos = torch.tensor([0])
sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent
output = generate.speculative_decoding(
draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs
)
# target model always accepts draft model's output, thus the output of the `speculative_decoding`
# is 4 tokens (3 accepted draft tokens + 1 sampled from target model's output)
assert len(output) == 4
assert torch.all((output >= 3) & (output <= 8))
def test_speculative_decoding_target_sometimes_accepts_draft_tokens():
class DraftModel(nn.Module):
def forward(self, **kwargs):
return torch.tensor([0, 0, 3, 4, 10, 9, 7, 8, 0, 0], dtype=torch.float)[None, None, ...] # (B, T, C)
class TargetModel(nn.Module):
def forward(self, idx, **kwargs):
return torch.tensor(
[
[0, 0, 0, 0, 10, 9, 0, 0, 0, 0],
[0, 0, 0, 0, 10, 9, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 10],
],
dtype=torch.float,
)[None, ...] # (B, T, C)
draft_model = DraftModel()
target_model = TargetModel()
token = torch.tensor([-1])
input_pos = torch.tensor([0])
sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent
output = generate.speculative_decoding(
draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs
)
# target model accepts only 2 out of 3 draft model's output, thus the output of the `speculative_decoding`
# is 3 tokens (2 accepted draft tokens + 1 sampled from adjusted distribution)
assert len(output) == 3
assert torch.equal(output, torch.tensor([4, 4, 9]))
@pytest.mark.parametrize("max_seq_length", (10, 15, 20, 25))
@pytest.mark.parametrize("speculative_k", (1, 2, 3))
def test_generate(max_seq_length, speculative_k):
# create a prompt
T = 5
input_idx = torch.arange(0, T)
max_new_tokens = max_seq_length - T
# prepare models
draft_model = GPT(Config(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8))
target_model = GPT(Config(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16))
for model in (draft_model, target_model):
model.max_seq_length = max_seq_length
model.set_kv_cache(batch_size=1)
# generate tokens
out, acceptance_rate = generate.generate(
draft_model, target_model, input_idx, T + max_new_tokens, top_k=1, speculative_k=speculative_k
)
# validate
assert out.size(0) == T + max_new_tokens - 1, (out.size(0), T + max_new_tokens - 1)
assert 0.0 <= acceptance_rate <= 1.0
@_RunIf(min_cuda_gpus=1) # speculative decoding makes sense only on a GPU
def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
# prepare configs for draft and target models
draft_model_dir = fake_checkpoint_dir / "draft_model"
draft_model_dir.mkdir()
target_model_dir = fake_checkpoint_dir / "target_model"
target_model_dir.mkdir()
draft_model_config = dict(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8)
target_model_config = dict(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16)
(draft_model_dir / "model_config.yaml").write_text(yaml.dump(draft_model_config))
(target_model_dir / "model_config.yaml").write_text(yaml.dump(target_model_config))
# create empty files required for validation
for model_dir in (draft_model_dir, target_model_dir):
(model_dir / "tokenizer.json").touch()
(model_dir / "tokenizer_config.json").touch()
(model_dir / "lit_model.pth").touch()
# moke functions
module_mock = Mock()
module_mock.config.block_size = 128
load_mock = Mock()
load_mock.return_value = load_mock
monkeypatch.setattr(generate, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generated_tokens = torch.tensor([3, 2, 1])
acceptance_rate = 0.0
generate_mock.return_value = (generated_tokens, acceptance_rate)
monkeypatch.setattr(generate, "generate", generate_mock)
# do the sampling
num_samples = 2
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
generate.main(
draft_model_checkpoint_dir=draft_model_dir,
target_model_checkpoint_dir=target_model_dir,
temperature=2.0,
top_k=2,
top_p=0.9,
num_samples=num_samples,
)
assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value[0])
assert (
generate_mock.mock_calls
== [
call(
ANY,
ANY,
tensor_like,
53,
temperature=2.0,
top_k=2,
top_p=0.9,
stop_tokens=[tokenizer_mock.return_value.eos_id],
speculative_k=3,
)
]
* num_samples
)
expected_output = "foo bar baz\nAcceptance rate: 0.00%\n" * num_samples
# Allow for the config to be printed before the expected repeated strings.
pattern = rf".*^{re.escape(expected_output.strip())}$.*"
assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)
err_value = err.getvalue()
expected_parts = [
"'padded_vocab_size': 512",
"'n_layer': 2",
"'n_head': 4",
]
assert all(part in err_value for part in expected_parts)
def test_cli():
args = ["litgpt", "generate_speculatively", "-h"]
output = subprocess.check_output(args)
output = str(output.decode())
assert "Default generation option" in output
================================================
FILE: tests/test_lora.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stdout
from copy import deepcopy
from io import StringIO
from itertools import product
from unittest import mock
from unittest.mock import Mock
import pytest
import torch
import yaml
from lightning import Fabric
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.wrappers import _FabricOptimizer
from torch._dynamo.backends import debugging
from torch.distributed.device_mesh import init_device_mesh
from torch.nn import functional as F
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import litgpt.config as config_module
import litgpt.finetune.lora as module
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca
from litgpt.lora import GPT as LoRAGPT
from litgpt.lora import (
CausalSelfAttention,
Config,
LoRALinear,
LoRAQKVLinear,
lora_filter,
mark_only_lora_as_trainable,
merge_lora_weights,
)
from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention
from litgpt.model import GPT as BaseGPT
from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
from litgpt.utils import _RunIf
def test_lora_layer_replacement():
config = Config(n_layer=2, n_head=4, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1)
model = LoRAGPT(config)
assert isinstance(model.transformer.h[0].attn, LoRACausalSelfAttention)
assert isinstance(model.transformer.h[1].attn, LoRACausalSelfAttention)
assert isinstance(model.lm_head, LoRALinear)
assert isinstance(model.transformer.h[0].mlp.proj, LoRALinear)
def test_lora_merge():
config = Config(
n_layer=1,
n_head=2,
n_embd=8,
block_size=8,
vocab_size=8,
lora_r=8,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
lora_projection=True,
)
model = LoRAGPT(config)
model.train()
attn_proj = model.transformer.h[0].attn.proj
initial_weight = attn_proj.linear.weight.clone()
assert torch.equal(attn_proj.linear.weight, initial_weight)
# perform an update to the LoRA weights
mark_only_lora_as_trainable(model)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64))
y.sum().backward()
optimizer.step()
optimizer.zero_grad()
# the weight remains unchanged (only lora A and B change)
assert torch.equal(attn_proj.linear.weight, initial_weight)
# calling merge() multiple times in a row should not merge multiple times
merge_lora_weights(model)
assert attn_proj.merged
weight_after = attn_proj.linear.weight.clone()
merge_lora_weights(model)
merge_lora_weights(model)
assert torch.equal(attn_proj.linear.weight, weight_after)
# check that `W_after = W_initial + (A x B)`
delta_w = attn_proj.get_lora_AB()
torch.testing.assert_close(weight_after, initial_weight + delta_w)
def test_lora_mqa_gqa():
# MHA
config = Config(
n_layer=1,
n_head=4,
n_embd=8,
block_size=1,
vocab_size=1,
lora_r=2,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
)
assert config.n_query_groups == config.n_head
model = LoRAGPT(config)
attn = model.transformer.h[0].attn.qkv
for p in attn.linear.parameters():
torch.nn.init.zeros_(p)
torch.nn.init.ones_(attn.lora_B)
lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
assert attn.linear.weight.shape == (24, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (16, 2)
assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 24)
bsz, ctx_len, in_dim = 2, 30, 8
x_in = torch.randn(bsz, ctx_len, in_dim)
out = attn(x_in)
non_lora_ind = list(set(range(24)).difference(lora_ind))
assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0
# MQA
config.n_query_groups = 1
model = LoRAGPT(config)
attn = model.transformer.h[0].attn.qkv
for p in attn.linear.parameters():
torch.nn.init.zeros_(p)
torch.nn.init.ones_(attn.lora_B)
lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
assert attn.linear.weight.shape == (12, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (10, 2)
assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 12)
bsz, ctx_len, in_dim = 2, 30, 8
x_in = torch.randn(bsz, ctx_len, in_dim)
out = attn(x_in)
non_lora_ind = list(set(range(12)).difference(lora_ind))
assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0
# GQA
config.n_query_groups = 2
model = LoRAGPT(config)
attn = model.transformer.h[0].attn.qkv
for p in attn.linear.parameters():
torch.nn.init.zeros_(p)
torch.nn.init.ones_(attn.lora_B)
lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
assert attn.linear.weight.shape == (16, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (12, 2)
assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 16)
bsz, ctx_len, in_dim = 2, 30, 8
x_in = torch.randn(bsz, ctx_len, in_dim)
out = attn(x_in)
non_lora_ind = list(set(range(16)).difference(lora_ind))
assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0
@pytest.mark.parametrize(
"n_head, n_query_groups, enable_lora",
[
(4, 2, (True, False, True)), # GQA: Q+V only
(4, 1, (False, True, True)), # MQA: K+V only
(4, 2, (True, True, False)), # GQA: Q+K only
(8, 2, (True, True, True)), # GQA: all enabled, different ratio
(4, 4, (False, False, True)), # MHA: V only
],
)
def test_lora_ind_correctness(n_head, n_query_groups, enable_lora):
"""Verify lora_ind correctly partitions Q, K, V regions using head_size-based sizes."""
n_embd = 16
config = Config(
n_layer=1,
n_head=n_head,
n_embd=n_embd,
block_size=1,
vocab_size=1,
lora_r=2,
lora_alpha=8,
lora_dropout=0.0,
lora_query=enable_lora[0],
lora_key=enable_lora[1],
lora_value=enable_lora[2],
n_query_groups=n_query_groups,
)
model = LoRAGPT(config)
attn = model.transformer.h[0].attn.qkv
head_size = n_embd // n_head
q_size = head_size * n_head
kv_size = head_size * n_query_groups
expected_ind = []
if enable_lora[0]:
expected_ind.extend(range(0, q_size))
if enable_lora[1]:
expected_ind.extend(range(q_size, q_size + kv_size))
if enable_lora[2]:
expected_ind.extend(range(q_size + kv_size, q_size + 2 * kv_size))
assert torch.equal(attn.lora_ind, torch.tensor(expected_ind))
# Verify zero_pad output dimension matches full QKV size
total_qkv = q_size + 2 * kv_size
lora_out_dim = sum(attn.qkv_shapes)
x = torch.randn(1, 1, lora_out_dim)
assert attn.zero_pad(x).shape[-1] == total_qkv
def test_lora_filter(tmp_path):
fabric = Fabric(devices=1)
model = LoRAGPT.from_name("pythia-14m", n_layer=3, lora_r=1, lora_query=True, lora_value=True)
save_path = tmp_path / "model.pth"
fabric.save(save_path, {"model": model}, filter={"model": lora_filter})
saved = torch.load(save_path)["model"]
expected = {
"transformer.h.1.attn.qkv.lora_B",
"transformer.h.2.attn.qkv.lora_B",
"transformer.h.2.attn.qkv.lora_A",
"transformer.h.1.attn.qkv.lora_A",
"transformer.h.0.attn.qkv.lora_A",
"transformer.h.0.attn.qkv.lora_B",
}
assert set(saved) == expected
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
out_dir=out_dir,
precision="32-true",
train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),
eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
)
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
"lit_model.pth.lora",
"model_config.yaml",
"tokenizer_config.json",
"tokenizer.json",
"hyperparameters.yaml",
"prompt_style.yaml",
}
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()
logs = stdout.getvalue()
assert logs.count("(step)") == 6
assert logs.count("val loss") == 4 # 3 validations + 1 final validation
assert logs.count("Final evaluation") == 1
assert "of trainable parameters: 512" in logs
def test_lora_init_when_linear_overridden():
class MyLinear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
# this needs to be implemented to demonstrate the failure
super().__init__(*args, **kwargs)
original_linear = torch.nn.Linear
# Our bnb does this sort of monkey patching
torch.nn.Linear = MyLinear
layer = LoRAQKVLinear(1, 1, 1, 1, 1)
assert isinstance(layer.linear, original_linear)
torch.nn.Linear = original_linear
@pytest.mark.parametrize(
("apply_to", "target_layer_names", "mlp_class_name"),
(
("lora_projection", "transformer.h.0.attn.proj", "GptNeoxMLP"),
("lora_mlp", {"transformer.h.0.mlp.fc", "transformer.h.0.mlp.proj"}, "GptNeoxMLP"),
("lora_head", "lm_head", "GptNeoxMLP"),
("lora_projection", "transformer.h.0.attn.proj", "LLaMAMLP"),
("lora_mlp", {"transformer.h.0.mlp.fc_1", "transformer.h.0.mlp.fc_2", "transformer.h.0.mlp.proj"}, "LLaMAMLP"),
("lora_head", "lm_head", "LLaMAMLP"),
),
)
def test_lora_linear_utilization(apply_to, target_layer_names, mlp_class_name):
config = Config(
n_layer=1,
n_head=4,
n_embd=8,
block_size=1,
vocab_size=1,
lora_r=2,
lora_alpha=8,
lora_dropout=0.1,
mlp_class_name=mlp_class_name,
intermediate_size=8 * 3,
**{apply_to: True},
)
model = LoRAGPT(config)
state_dict = model.state_dict()
if isinstance(target_layer_names, str):
target_layer_names = {target_layer_names}
lora_sublayers = (".lora_A", ".lora_B")
# check that all the target layers have LoRA weights
for layer_name in target_layer_names:
for lora_sublayer in lora_sublayers:
assert layer_name + lora_sublayer in state_dict
# check that only target layers have LoRA weights
lora_params = [k for k in state_dict if k.endswith(lora_sublayers)]
lora_params = {k[:-7] for k in lora_params}
assert lora_params == target_layer_names
@torch.inference_mode()
@pytest.mark.parametrize(
"apply_to", (None, "lora_query", "lora_key", "lora_value", "lora_projection", "lora_mlp", "lora_head")
)
def test_lora_gpt_apply_lora_forward_no_exception(apply_to):
config = Config(n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1)
if apply_to:
setattr(config, apply_to, True)
input_ids = torch.tensor([[1]])
model = LoRAGPT(config)
model.eval()
model(input_ids)
@torch.inference_mode()
@pytest.mark.parametrize("n_query_groups", (1, 2, 3, 6))
@pytest.mark.parametrize("apply_to", product((False, True), repeat=3))
def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, apply_to):
keys = ("lora_query", "lora_key", "lora_value")
values = apply_to
apply_to = dict(zip(keys, values))
config = Config(
n_layer=1,
n_head=6,
n_embd=12,
block_size=1,
vocab_size=1,
lora_r=2,
lora_alpha=8,
lora_dropout=0.1,
n_query_groups=n_query_groups,
**apply_to,
)
model = LoRAGPT(config)
merge_lora_weights(model)
input_ids = torch.tensor([[1]])
model(input_ids)
@torch.inference_mode()
@pytest.mark.parametrize("head_size", (1, 2, 4))
@pytest.mark.parametrize("n_head", (1, 2, 3, 6, 12))
@pytest.mark.parametrize(
"enable_lora",
[
(False, False, True),
(False, True, False),
(False, True, True),
(True, False, False),
(True, False, True),
(True, True, False),
(True, True, True),
],
)
def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):
C = 12
layer = LoRAQKVLinear(
C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
)
x = torch.randn((1, 1, C))
a = F.linear(x, layer.lora_A).transpose(-2, -1) # after_A
b = layer.lora_B.data.unsqueeze(-1)
# original PyTorch conv1d function output
conv1d_pytorch = F.conv1d(a, b, groups=sum(layer.enable_lora))
# custom conv1d
conv1d_custom = layer.conv1d(a, b)
# custom conv1d forced to split, apply and concat tensors
layer.n_head = layer.n_query_groups + 1
conv1d_custom_forced = layer.conv1d(a, b)
assert torch.allclose(conv1d_pytorch, conv1d_custom)
assert torch.allclose(conv1d_pytorch, conv1d_custom_forced)
@pytest.mark.parametrize(("rank", "expected_merged"), ((0, False), (1, True)))
def test_lora_linear_weights_merged_status(rank, expected_merged):
layer = LoRALinear(10, 10, r=rank)
assert not layer.merged
layer.merge()
assert layer.merged == expected_merged
@pytest.mark.parametrize(
("rank", "enable_lora", "expected_merged"),
((0, True, False), (1, True, True), (0, False, False), (1, False, False)),
)
def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):
C = 10
layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
assert not layer.merged
layer.merge()
assert layer.merged == expected_merged
@_RunIf(min_cuda_gpus=1)
def test_lora_merge_with_bitsandbytes():
if not _BITSANDBYTES_AVAILABLE:
pytest.skip("BNB not available")
import bitsandbytes as bnb
config = Config(
n_layer=1,
n_head=2,
n_embd=8,
block_size=8,
vocab_size=8,
lora_r=8,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
lora_projection=True,
)
fabric = Fabric(devices=1, plugins=BitsandbytesPrecision("nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"}))
model = LoRAGPT(config)
mark_only_lora_as_trainable(model)
from bitsandbytes.optim import PagedAdamW
optimizer = PagedAdamW(model.parameters(), lr=1.0)
model, optimizer = fabric.setup(model, optimizer)
model.train()
attn_proj = model.transformer.h[0].attn.proj
initial_weight = attn_proj.linear.weight.clone()
initial_weight_kwargs = attn_proj.linear.weight.__dict__
# this was skipped
assert model.lm_head.linear.weight.dtype is torch.float32
assert attn_proj.linear.weight.dtype is torch.uint8
# perform an update to the LoRA weights
y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device))
loss = y.sum()
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
# the weight remains unchanged (only lora A and B change)
assert torch.equal(attn_proj.linear.weight, initial_weight)
# calling merge() multiple times in a row should not merge multiple times
merge_lora_weights(model)
assert attn_proj.merged
weight_after = attn_proj.linear.weight.clone()
merge_lora_weights(model)
merge_lora_weights(model)
assert torch.equal(attn_proj.linear.weight, weight_after)
# check that `W_after = W_initial + (A x B)`
delta_w = attn_proj.get_lora_AB()
# dequantize initial weight and sum with delta_w
initial_weight_data = (
bnb.functional.dequantize_4bit(initial_weight.data, initial_weight_kwargs["quant_state"]) + delta_w
)
# quantize again
initial_weight_data = bnb.nn.Params4bit(
initial_weight_data.to("cpu"), requires_grad=False, **initial_weight_kwargs
).to(initial_weight.device)
torch.testing.assert_close(weight_after, initial_weight_data)
def test_lora_gpt_init_weights():
config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_head=True)
model = LoRAGPT(config)
param = model.lm_head.lora_B.data
assert (param == 0).all()
torch.nn.init.constant_(param, 1.23)
assert (param != 0).any()
model.apply(model._init_weights)
assert (param == 0).all()
@pytest.mark.parametrize("name", [c["name"] for c in config_module.configs])
def test_base_model_can_be_lora_loaded(name):
kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32}
base_model = BaseGPT.from_name(name, **kwargs)
base_model_state_dict = base_model.state_dict()
lora_model = LoRAGPT.from_name(
name,
**kwargs,
lora_r=1,
lora_query=True,
lora_key=True,
lora_value=True,
lora_projection=True,
lora_mlp=True,
lora_head=True,
)
keys = lora_model.load_state_dict(base_model_state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert lora_filter(k, None)
@_RunIf(dynamo=True)
@torch.inference_mode()
def test_lora_compile():
model = LoRAGPT.from_name(
"pythia-14m",
n_layer=3,
lora_r=8,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_key=True,
lora_value=True,
lora_projection=True,
lora_mlp=True,
lora_head=True,
)
x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
explanation = torch._dynamo.explain(model)(x)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
model = LoRAGPT(model.config)
model.set_kv_cache(2)
input_pos = torch.arange(model.config.block_size)
explanation = torch._dynamo.explain(model)(x, input_pos)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
@torch.inference_mode()
def test_against_hf_mixtral():
device = torch.device("cpu")
dtype = torch.float32
ours_config = Config.from_name(
"Mixtral-8x7B-Instruct-v0.1",
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
n_expert=4,
lora_r=1,
lora_key=True,
lora_value=True,
)
T = 5
theirs_config = MixtralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
num_local_experts=ours_config.n_expert,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = MixtralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = LoRAGPT(ours_config).to(device)
keys = ours_model.load_state_dict(state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert lora_filter(k, None)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 5
ours_config = Config.from_name(
model_name,
n_layer=2,
n_head=16,
n_embd=32,
head_size=4,
intermediate_size=86,
lora_r=1,
lora_query=True,
lora_key=True,
lora_value=True,
)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = LoRAGPT(ours_config).to(device)
keys = ours_model.load_state_dict(state_dict, strict=False)
assert not keys.unexpected_keys
for k in keys.missing_keys:
assert lora_filter(k, None)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b"))
def test_against_original_gemma_2(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_2({}, state_dict, theirs_state_dict)
ours_model = LoRAGPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)
@torch.inference_mode()
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
def test_against_original_gemma_3(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_3({}, state_dict, theirs_state_dict)
ours_model = LoRAGPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
@_RunIf(min_cuda_gpus=1)
def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
if not _BITSANDBYTES_AVAILABLE:
pytest.skip("BNB not available")
from bitsandbytes.optim import PagedAdamW
model_config = dict(
block_size=128,
n_layer=2,
n_embd=8,
n_head=4,
padded_vocab_size=8,
bias=True,
lora_r=8,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
lora_projection=True,
)
(fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config))
tokenizer_mock = Mock()
tokenizer_mock.return_value = tokenizer_mock
tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])
monkeypatch.setattr(module, "Tokenizer", tokenizer_mock)
monkeypatch.setattr(module, "load_checkpoint", Mock())
monkeypatch.setattr(module, "merge_lora", Mock())
train_mock = Mock()
train_mock.return_value = {
"raw_tokens": 1000,
"raw_tokens_plus_prompt_template": 1100,
"raw_tokens_plus_prompt_template_and_padding": 1200,
}
monkeypatch.setattr(module, "fit", train_mock)
stdout = StringIO()
with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py", str(fake_checkpoint_dir)]):
module.setup(
fake_checkpoint_dir,
data=Alpaca(
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0
),
out_dir=tmp_path,
precision="16-true",
quantize="bnb.nf4-dq",
)
_, kwargs = train_mock.call_args
fabric = kwargs["fabric"]
model = kwargs["model"]
optimizer = kwargs["optimizer"]
model.transformer.wte = model.transformer.wte.half()
assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
assert isinstance(optimizer, _FabricOptimizer)
assert isinstance(optimizer._optimizer, PagedAdamW)
dtype_to_name = {"torch.uint8": set(), "torch.float16": set()}
for name, layer in model.named_parameters():
name = name[len("_forward_module.") :]
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.uint8": {
"transformer.h.0.attn.qkv.linear.weight",
"transformer.h.0.attn.proj.linear.weight",
"transformer.h.0.mlp.fc.linear.weight",
"transformer.h.1.mlp.proj.linear.weight",
"transformer.h.0.mlp.proj.linear.weight",
"transformer.h.1.attn.qkv.linear.weight",
"lm_head.linear.weight",
"transformer.h.1.attn.proj.linear.weight",
"transformer.h.1.mlp.fc.linear.weight",
},
"torch.float16": {
"transformer.h.0.attn.qkv.lora_B",
"transformer.h.0.norm_2.weight",
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
"transformer.h.1.mlp.fc.linear.bias",
"transformer.ln_f.bias",
"transformer.h.1.attn.qkv.lora_B",
"transformer.h.1.attn.proj.linear.bias",
"transformer.h.1.norm_1.weight",
"transformer.h.1.attn.qkv.linear.bias",
"transformer.h.1.attn.qkv.lora_A",
"transformer.h.1.norm_1.bias",
"transformer.h.1.norm_2.bias",
"transformer.h.0.attn.proj.linear.bias",
"transformer.h.0.norm_1.bias",
"transformer.h.0.mlp.proj.linear.bias",
"transformer.h.0.mlp.fc.linear.bias",
"transformer.h.0.norm_2.bias",
"transformer.ln_f.weight",
"transformer.h.0.attn.qkv.lora_A",
"transformer.h.1.norm_2.weight",
"transformer.h.1.mlp.proj.linear.bias",
"transformer.h.0.norm_1.weight",
"transformer.h.0.attn.qkv.linear.bias",
},
}
assert {p.name for p in tmp_path.rglob("*.lora")} == {"lit_model.pth.lora"}
state_dict = torch.load(tmp_path / "final" / "lit_model.pth.lora")
assert len(state_dict) == 1
dtype_to_name = {"torch.float16": set()}
for name, layer in state_dict["model"].items():
dtype_to_name[str(layer.dtype)].add(name)
assert dtype_to_name == {
"torch.float16": {
"transformer.h.1.attn.qkv.lora_A",
"transformer.h.0.attn.qkv.lora_A",
"transformer.h.0.attn.qkv.lora_B",
"transformer.h.1.attn.qkv.lora_B",
}
}
logs = stdout.getvalue()
assert "of trainable parameters: 512" in logs
assert "of non-trainable parameters: 1,888" in logs
@_RunIf(standalone=True, min_cuda_gpus=2)
def test_lora_model_fsdp_init():
config = Config(
n_layer=1,
n_head=2,
n_embd=8,
block_size=8,
vocab_size=8,
lora_r=8,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=False,
lora_projection=True,
)
fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
fabric.launch()
with fabric.init_module(empty_init=True):
model = LoRAGPT(config)
x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device)
model = fabric.setup(model)
y = model(x)
assert y.shape == torch.Size([2, 8, 512])
# verify that all the parameters, buffers and other attributes aren't on `meta` device
for m in model.modules():
for p_name, parameter in m.named_parameters():
assert not parameter.is_meta, f"Parameter `{p_name}` isn't materialized."
for b_name, buffer in m._buffers.items():
assert not buffer.is_meta, f"Buffer `{b_name}` isn't materialized."
for attr_name, attr_value in m.__dict__.items():
if isinstance(attr_value, torch.Tensor):
assert not attr_value.is_meta, f"Attribute `{attr_name}` isn't materialized."
def test_zero_pad_cpu_and_mocked_mps():
head_size = 64
n_head = 12
n_query_groups = 3
in_features = 128
q_size = head_size * n_head
kv_size = head_size * n_query_groups
out_features = q_size + 2 * kv_size
enable_lora = [True, False, True]
r = 4
model = LoRAQKVLinear(
in_features=in_features,
out_features=out_features,
head_size=head_size,
n_head=n_head,
n_query_groups=n_query_groups,
r=r,
enable_lora=enable_lora,
)
batch_size = 64
seq_len = 64
# embed_dim = sum of enabled qkv shapes: Q (q_size) + V (kv_size)
embed_dim = q_size + kv_size
x = torch.randn(batch_size, seq_len, embed_dim)
result_cpu = model.zero_pad(x)
with mock.patch("torch.backends.mps.is_available", return_value=True):
with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device:
mock_device.return_value = torch.device("mps")
result_mps = model.zero_pad(x)
assert result_cpu.shape == result_mps.shape, "Shape mismatch between CPU and MPS"
assert torch.allclose(result_cpu, result_mps), "Tensor values mismatch between CPU and MPS"
def test_load_legacy_state_dict():
"""Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers."""
config = Config(
n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, lora_r=8, lora_alpha=16, lora_dropout=0.1
)
attention_1 = CausalSelfAttention(config=config, block_idx=0)
# make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`
# and make them interleaved
state_dict = deepcopy(attention_1.state_dict())
state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config)
state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config)
attention_2 = CausalSelfAttention(config=config, block_idx=0)
attention_2.load_state_dict(state_dict)
@_RunIf(standalone=True, min_cuda_gpus=2)
def test_parallelize_fn():
from litgpt.finetune.lora import parallelize_fn
config = Config(
n_layer=2,
n_head=4,
n_embd=32,
block_size=8,
vocab_size=8,
lora_r=4,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
lora_projection=True,
)
fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
fabric.launch()
model = LoRAGPT(config)
mark_only_lora_as_trainable(model)
# create device mesh for data parallel
device_mesh = init_device_mesh(
device_type=fabric.device.type,
mesh_shape=(2, 1),
mesh_dim_names=("data_parallel", "tensor_parallel"),
)
# test with activation checkpointing enabled (default)
parallelized_model = parallelize_fn(model, device_mesh, activation_checkpointing=True)
# verify the model is still functional
assert parallelized_model is not None
assert isinstance(parallelized_model, LoRAGPT)
parallelized_model = parallelized_model.to(fabric.device)
# test forward pass to ensure the parallelized model works
x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device)
# verify forward pass works
with torch.no_grad():
output = parallelized_model(x)
assert output.shape == (1, config.block_size, config.padded_vocab_size)
# test with activation checkpointing disabled
model_no_checkpoint = LoRAGPT(config)
mark_only_lora_as_trainable(model_no_checkpoint)
parallelized_model_no_checkpoint = parallelize_fn(model_no_checkpoint, device_mesh, activation_checkpointing=False)
# verify the model is still functional
assert parallelized_model_no_checkpoint is not None
assert isinstance(parallelized_model_no_checkpoint, LoRAGPT)
# test forward pass to ensure the parallelized model works
parallelized_model_no_checkpoint = parallelized_model_no_checkpoint.to(fabric.device)
with torch.no_grad():
output = parallelized_model_no_checkpoint(x)
assert output.shape == (1, config.block_size, config.padded_vocab_size)
# verify that all parameters are properly distributed (not on meta device)
for mod in parallelized_model.modules():
for param_name, param in mod.named_parameters():
if param.requires_grad: # Only check trainable parameters (LoRA parameters)
assert not param.is_meta, f"Parameter `{param_name}` should not be on meta device"
assert param.device.type == "cuda", f"Parameter `{param_name}` should be on CUDA device"
@_RunIf(standalone=True, min_cuda_gpus=2)
def test_load_from_full_model_state_dict():
from litgpt.finetune.lora import parallelize_fn
from litgpt.utils import load_from_full_model_state_dict
config = Config(
n_layer=2,
n_head=4,
n_embd=32,
block_size=8,
vocab_size=8,
lora_r=4,
lora_alpha=8,
lora_dropout=0.1,
lora_query=True,
lora_value=True,
lora_projection=True,
lora_mlp=True,
lora_head=True,
)
# set up distributed environment with FSDP
fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
fabric.launch()
# create a reference model to get the full state dict
reference_model = LoRAGPT(config)
mark_only_lora_as_trainable(reference_model)
# initialize the reference model with some values
with torch.no_grad():
for param in reference_model.parameters():
if param.requires_grad:
param.fill_(0.1)
# get the full state dict (simulating a checkpoint)
full_state_dict = {}
for name, param in reference_model.named_parameters():
# Convert parameters to checkpoint format (what load_from_full_model_state_dict expects)
if "norm" not in name and "wte" not in name and "ln_f" not in name:
# For linear layers, remove .linear from the name to simulate checkpoint format
checkpoint_name = name.replace(".linear.weight", ".weight").replace(".linear.bias", ".bias")
else:
# For norm, embedding, and layer norm layers, keep the original name
checkpoint_name = name
full_state_dict[checkpoint_name] = param.detach().clone()
# create distributed model
model = LoRAGPT(config)
mark_only_lora_as_trainable(model)
# set up device mesh for distributed model
device_mesh = init_device_mesh(
device_type=fabric.device.type,
mesh_shape=(2, 1),
mesh_dim_names=("data_parallel", "tensor_parallel"),
)
model = parallelize_fn(model, device_mesh, activation_checkpointing=False)
model = model.to(fabric.device)
# test with default parameters (strict=False, cpu_offload=False)
result = load_from_full_model_state_dict(
model=model,
full_sd=full_state_dict,
device=fabric.device,
strict=False,
cpu_offload=False,
)
# verify that the function returns the missing/unexpected keys
assert hasattr(result, "missing_keys")
assert hasattr(result, "unexpected_keys")
# verify that parameters are loaded correctly
for name, param in model.named_parameters():
if param.requires_grad:
# Check that parameter is not on meta device
assert not param.is_meta, f"Parameter {name} should not be on meta device"
# Check that parameter is on the correct device
assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device"
# test with cpu_offload=True
model_cpu_offload = LoRAGPT(config)
mark_only_lora_as_trainable(model_cpu_offload)
model_cpu_offload = parallelize_fn(model_cpu_offload, device_mesh, activation_checkpointing=False)
model_cpu_offload = model_cpu_offload.to(fabric.device)
result_cpu_offload = load_from_full_model_state_dict(
model=model_cpu_offload,
full_sd=full_state_dict,
device=fabric.device,
strict=False,
cpu_offload=True,
)
# verify that parameters are loaded correctly with CPU offload
for name, param in model_cpu_offload.named_parameters():
if param.requires_grad:
# Check that parameter is not on meta device
assert not param.is_meta, f"Parameter {name} should not be on meta device"
# With cpu_offload, parameters might be on CPU
assert param.device.type in ["cpu", "cuda"], f"Parameter {name} should be on CPU or CUDA device"
# test with strict=True
model_strict = LoRAGPT(config)
mark_only_lora_as_trainable(model_strict)
model_strict = parallelize_fn(model_strict, device_mesh, activation_checkpointing=False)
model_strict = model_strict.to(fabric.device)
try:
result_strict = load_from_full_model_state_dict(
model=model_strict,
full_sd=full_state_dict,
device=fabric.device,
strict=True,
cpu_offload=False,
)
# If strict loading succeeds, verify parameters
for name, param in model_strict.named_parameters():
if param.requires_grad:
assert not param.is_meta, f"Parameter {name} should not be on meta device"
assert param.device.type == "cuda", f"Parameter {name} should be on CUDA device"
except RuntimeError as e:
# strict=True might fail if there are missing keys, which is expected behavior
assert "Missing key(s)" in str(e) or "Unexpected key(s)" in str(e)
# test forward pass to ensure model still works after loading
x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device)
with torch.no_grad():
output = model(x)
assert output.shape == (1, config.block_size, config.padded_vocab_size)
output_cpu_offload = model_cpu_offload(x)
assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)
================================================
FILE: tests/test_merge_lora.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import shutil
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path
from unittest import mock
import pytest
import torch
import yaml
from litgpt.lora import GPT as LoRAGPT
from litgpt.lora import lora_filter
from litgpt.model import GPT
from litgpt.scripts.merge_lora import load_lora_metadata, merge_lora
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.parametrize(
("pretrained_dtype", "lora_dtype"), [(None, None), (torch.float16, torch.float32), (torch.float16, torch.bfloat16)]
)
def test_merge_lora(tmp_path, fake_checkpoint_dir, pretrained_dtype, lora_dtype):
pretrained_checkpoint_dir = tmp_path / "pretrained"
lora_checkpoint_dir = tmp_path / "lora"
shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir)
shutil.copytree(fake_checkpoint_dir, lora_checkpoint_dir)
(lora_checkpoint_dir / "lit_model.pth").unlink() # should not already exist
shutil.rmtree(tmp_path / "checkpoints")
# Create a fake pretrained checkpoint
config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16)
with open(pretrained_checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
yaml.dump(config, fp)
base_model = GPT.from_name("pythia-14m", **config).to(dtype=pretrained_dtype)
state_dict = base_model.state_dict()
assert len(state_dict) == 40
torch.save(state_dict, pretrained_checkpoint_dir / "lit_model.pth")
# Create a fake LoRA checkpoint
lora_kwargs = dict(lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_query=True, lora_value=True)
lora_model = LoRAGPT.from_name("pythia-14m", **config, **lora_kwargs).to(dtype=lora_dtype)
state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)}
assert len(state_dict) == 6
torch.save(state_dict, lora_checkpoint_dir / "lit_model.pth.lora")
hparams = dict(checkpoint_dir=str(pretrained_checkpoint_dir), **lora_kwargs)
with open(lora_checkpoint_dir / "hyperparameters.yaml", "w", encoding="utf-8") as file:
yaml.dump(hparams, file)
shutil.copyfile(pretrained_checkpoint_dir / "model_config.yaml", lora_checkpoint_dir / "model_config.yaml")
assert set(os.listdir(tmp_path)) == {"lora", "pretrained"}
merge_lora(lora_checkpoint_dir)
assert set(os.listdir(tmp_path)) == {"lora", "pretrained"}
assert set(os.listdir(lora_checkpoint_dir)) == {
"model_config.yaml",
"lit_model.pth",
"lit_model.pth.lora",
"tokenizer.json",
"tokenizer_config.json",
"hyperparameters.yaml",
}
# Assert that the merged weights can be loaded back into the base model
merged = torch.load(lora_checkpoint_dir / "lit_model.pth")
keys = base_model.load_state_dict(merged, strict=True)
assert not keys.missing_keys
assert not keys.unexpected_keys
# Attempt to merge again
stdout = StringIO()
with redirect_stdout(stdout):
merge_lora(lora_checkpoint_dir)
assert "LoRA weights have already been merged" in stdout.getvalue()
def test_load_lora_metadata(fake_checkpoint_dir):
assert not (fake_checkpoint_dir / "hyperparameters.yaml").is_file()
with pytest.raises(FileNotFoundError, match="missing a `hyperparameters.yaml` file"):
load_lora_metadata(fake_checkpoint_dir)
hparams = dict(precision="bf16-mixed", checkpoint_dir="checkpoints/meta-llama/Llama-2-7b", lora_r=8, lora_alpha=16)
with open(fake_checkpoint_dir / "hyperparameters.yaml", "w", encoding="utf-8") as file:
yaml.dump(hparams, file)
lora_args, pretrained_dir, precision = load_lora_metadata(fake_checkpoint_dir)
assert lora_args == dict(lora_r=8, lora_alpha=16)
assert pretrained_dir == Path("checkpoints/meta-llama/Llama-2-7b")
assert precision == "bf16-mixed"
================================================
FILE: tests/test_model.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from functools import partial
from unittest import mock
import pytest
import torch
from lightning import Fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.init import _materialize_meta_tensors
from torch._dynamo.backends import debugging
from torch.backends.cuda import (
SDPAParams,
SDPBackend,
can_use_efficient_attention,
can_use_flash_attention,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
)
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.falcon import FalconConfig, FalconForCausalLM
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.gemma3 import Gemma3Config, Gemma3ForCausalLM, Gemma3ForConditionalGeneration, Gemma3TextConfig
from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
from transformers.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
import litgpt.config as config_module
from litgpt import GPT, Config
from litgpt.model import CausalSelfAttention, batched_index_copy_
from litgpt.scripts.convert_hf_checkpoint import (
copy_weights_falcon,
copy_weights_gemma_2,
copy_weights_gemma_3,
copy_weights_gpt_neox,
copy_weights_hf_llama,
copy_weights_olmo2,
copy_weights_phi,
copy_weights_qwen_2_5,
copy_weights_qwen_3,
)
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
from litgpt.utils import _RunIf
@torch.inference_mode()
@pytest.mark.parametrize("rotary_pct", (0.25, 1))
@pytest.mark.parametrize("batch_size", (1, 3))
@pytest.mark.parametrize("n_embd", (16, 32))
@pytest.mark.parametrize("parallel_residual", (False, True))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:
torch.set_default_dtype(dtype)
ours_config = Config(
block_size=64,
vocab_size=100,
n_layer=4,
n_head=8,
n_embd=n_embd,
rotary_percentage=rotary_pct,
parallel_residual=parallel_residual,
)
assert ours_config.padded_vocab_size == 512
theirs_config = GPTNeoXConfig(
hidden_act="gelu",
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
initializer_range=0.02,
intermediate_size=ours_config.intermediate_size,
layer_norm_eps=ours_config.norm_eps,
max_position_embeddings=ours_config.block_size,
rotary_emb_base=10000,
rotary_pct=ours_config.rotary_percentage,
vocab_size=ours_config.padded_vocab_size,
use_parallel_residual=ours_config.parallel_residual,
attn_implementation="eager",
)
state_dict = {}
theirs_model = GPTNeoXForCausalLM(theirs_config).to(device)
# load the hf initialization into our model
copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict())
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
token_sample = torch.randint(
0, ours_config.padded_vocab_size, size=(batch_size, ours_config.block_size), dtype=torch.int64, device=device
)
theirs = theirs_model(token_sample)["logits"]
ours = ours_model(token_sample)
torch.testing.assert_close(ours, theirs, rtol=1e-2, atol=1e-2)
@torch.inference_mode()
@pytest.mark.parametrize(
"kwargs",
[
dict(name="falcon-180B", n_layer=2, n_head=8, n_query_groups=4, n_embd=32),
dict(name="falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32),
],
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_hf_falcon(kwargs, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(**kwargs)
theirs_config = FalconConfig(
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_kv_heads=ours_config.n_query_groups,
num_hidden_layers=ours_config.n_layer,
parallel_attn=ours_config.parallel_residual,
vocab_size=ours_config.padded_vocab_size,
bias=ours_config.bias,
new_decoder_architecture=True,
)
theirs_model = FalconForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_falcon(ours_config, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_open_llama_3b(device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
T = 5
theirs_config = LlamaConfig(
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
"ours_kwargs",
[
{"name": "Llama-2-7b-hf"},
{"name": "CodeLlama-7b-hf"},
{"name": "Llama-2-70b-chat-hf", "n_query_groups": 1},
{"name": "Llama-3-8B"},
{"name": "Llama-3-8B-Instruct"},
{"name": "Llama-3.1-405B", "n_query_groups": 4},
{"name": "Llama-3.1-8B"},
{"name": "Llama-3.1-8B-Instruct"},
{"name": "Llama-3.2-1B"},
{"name": "Llama-3.2-3B"},
{"name": "Llama-3.3-70B-Instruct"},
{"name": "R1-Distill-Llama-8B"},
{"name": "R1-Distill-Llama-70B"},
],
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)],
),
],
)
def test_against_hf_phi(model_name, device, dtype):
from transformers.models.phi.configuration_phi import PhiConfig
from transformers.models.phi.modeling_phi import PhiForCausalLM
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
)
T = 5
theirs_config = PhiConfig(
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
torch_dtype=dtype,
)
theirs_model = PhiForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
"model_name",
(
"Phi-3-mini-4k-instruct",
"Phi-3-mini-128k-instruct",
"Phi-3.5-mini-instruct",
"phi-4",
"Phi-4-mini-instruct",
"Phi-4-reasoning",
"Phi-4-mini-reasoning",
),
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)],
),
],
)
def test_against_hf_phi_3(model_name, device, dtype):
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=4,
n_query_groups=4,
n_embd=256,
)
T = 5
theirs_config = Phi3Config(
attention_bias=ours_config.bias,
head_dim=ours_config.head_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
num_key_value_heads=ours_config.n_query_groups,
pad_token_id=ours_config.padded_vocab_size - 1,
partial_rotary_factor=ours_config.rotary_percentage,
rms_norm_eps=ours_config.norm_eps,
rope_theta=ours_config.rope_base,
torch_dtype=dtype,
vocab_size=ours_config.padded_vocab_size,
)
theirs_model = Phi3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
@pytest.mark.parametrize("model_name", ["Mistral-7B-Instruct-v0.1", "Mistral-7B-v0.1"])
def test_against_mistral_hf_models(device, dtype, model_name):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
)
theirs_config = MistralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attn_implementation="eager",
sliding_window=ours_config.sliding_window_size,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = MistralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_mathstral_hf_models(device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
"Mathstral-7B-v0.1",
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = MistralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = MistralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1"))
def test_against_hf_mixtral(model_name):
device = torch.device("cpu")
dtype = torch.float32
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
n_head=8,
n_query_groups=2,
intermediate_size=86,
n_expert=4,
)
T = 5
theirs_config = MixtralConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
num_local_experts=ours_config.n_expert,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = MixtralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_olmo(model_name, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
intermediate_size=86,
)
T = 5
theirs_config = OlmoConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_positional_embeddings=T,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
tie_word_embeddings=(model_name == "OLMo-1B-hf"),
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = OlmoForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_olmo2(model_name, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = Olmo2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_query_groups,
max_positional_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
attention_bias=ours_config.bias,
rope_theta=ours_config.rope_base,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = Olmo2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_stablelm_zephyr_3b(device, dtype):
torch.set_default_dtype(dtype)
T = 5
ours_config = Config.from_name("stablelm-zephyr-3b", n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = AutoConfig.from_pretrained(
"stabilityai/stablelm-zephyr-3b",
trust_remote_code=True,
num_hidden_layers=ours_config.n_layer,
num_attention_heads=ours_config.n_head,
num_key_value_heads=ours_config.n_head,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
torch_dtype=dtype,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 5
ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("gemma-2-9b", "gemma-2-27b"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_2(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_logit_softcapping=ours_config.attention_logit_softcapping,
final_logit_softcapping=ours_config.final_logit_softcapping,
initializer_range=1.0, # to make the affect of attention_logit_softcapping more prominent
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
)
theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_2({}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_gemma_3(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
rope_scaling={"factor": 8.0, "rope_type": "linear"},
rope_local_base_freq=ours_config.rope_local_base_freq,
)
theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_gemma_3({}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_multimodal_gemma_3(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
sliding_window_size=T // 2,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Gemma3Config(
Gemma3TextConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
sliding_window=ours_config.sliding_window_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
attn_implementation="eager",
query_pre_attn_scalar=ours_config.attention_scores_scalar,
rope_scaling={"factor": 8.0, "rope_type": "linear"},
rope_local_base_freq=ours_config.rope_local_base_freq,
)
)
theirs_model = Gemma3ForConditionalGeneration(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_gemma_3({}, state_dict, theirs_state_dict, config=ours_config)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
@torch.inference_mode()
@pytest.mark.parametrize(
"model_name", ["Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B"]
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_qwen_2_5(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Qwen2Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.attn_bias,
tie_word_embeddings=True,
)
theirs_model = Qwen2ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_qwen_2_5(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
"model_name",
[
"Qwen3-0.6B",
"Qwen3-8B",
"Qwen3-4B-Base",
"Qwen3-14B-Base",
"Qwen3-32B",
"Qwen3-4B-Thinking-2507",
"Qwen3-4B-Instruct-2507",
],
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_qwen_3(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Qwen3Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
tie_word_embeddings=False,
)
theirs_model = Qwen3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize(
"model_name", ["Qwen3-30B-A3B", "Qwen3-235B-A22B", "Qwen3-235B-A22B-Thinking-2507", "Qwen3-235B-A22B-Instruct-2507"]
)
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_qwen_3_moe(model_name, device, dtype):
torch.set_default_dtype(dtype)
T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
moe_intermediate_size=20,
n_expert=4,
n_expert_per_token=2,
)
theirs_config = Qwen3MoeConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
moe_intermediate_size=ours_config.moe_intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
tie_word_embeddings=False,
num_experts=ours_config.n_expert,
num_experts_per_tok=ours_config.n_expert_per_token,
norm_topk_prob=True,
)
theirs_model = Qwen3MoeForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_salamandra(model_name, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_smollm2(model_name, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_hf_falcon3(model_name, device, dtype):
torch.set_default_dtype(dtype)
ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size
theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
@_RunIf(dynamo=True)
@torch.inference_mode()
def test_model_compile():
model = GPT.from_name("pythia-14m", n_layer=3)
x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
explanation = torch._dynamo.explain(model)(x)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
model = GPT(model.config)
model.set_kv_cache(2)
input_pos = torch.arange(model.config.block_size)
explanation = torch._dynamo.explain(model)(x, input_pos)
assert isinstance(explanation, debugging.ExplainOutput)
assert explanation.graph_count == 1
assert explanation.graph_break_count == 0
@torch.inference_mode()
@pytest.mark.parametrize(
"max_seq_length", (25, pytest.param(23, marks=pytest.mark.xfail(raises=IndexError, strict=True)))
)
@pytest.mark.flaky(reruns=5)
def test_kv_cache(max_seq_length):
config = Config(block_size=25, padded_vocab_size=5, n_layer=2, n_head=2, n_embd=8)
model = GPT(config)
idx = torch.randint(0, model.config.padded_vocab_size, (1, 5))
max_new_tokens = 20
model.max_seq_length = max_seq_length
model.set_kv_cache(1)
def generate(logits):
logits = logits[:, -1:]
probs = torch.nn.functional.softmax(logits, dim=-1)
return torch.argmax(probs).unsqueeze(0).unsqueeze(0)
x_no_cache = idx
x_cache = idx
input_pos = torch.arange(0, 5)
for _ in range(max_new_tokens):
logits_no_cache = model(x_no_cache[:, -max_seq_length:])
out_no_cache = generate(logits_no_cache)
logits_cache = model(x_cache, input_pos)
out_cache = generate(logits_cache)
torch.testing.assert_close(out_no_cache, out_cache, rtol=0, atol=0)
x_no_cache = torch.cat((x_no_cache, out_no_cache), dim=1)
x_cache = out_cache
input_pos = input_pos[-1:] + 1
@torch.inference_mode()
def test_model_kv_cache_amp():
config = Config.from_name("pythia-14m", n_layer=2)
model = GPT(config)
encoded = torch.arange(45)
model.set_kv_cache(batch_size=1)
with torch.autocast("cpu", torch.bfloat16):
output = model(encoded.unsqueeze(0), encoded)
assert output.dtype is torch.bfloat16
@pytest.mark.parametrize("model_name", ["pythia-14m", "gemma-3-1b-it"])
def test_rope_cache_length(model_name):
config = Config.from_name(model_name, n_layer=2)
model = GPT(config)
model.max_seq_length = 128
rope_len = model.rope_cache_length()
assert rope_len == config.rope_n_elem
# Verify it works with set_kv_cache
model.set_kv_cache(batch_size=1)
assert model.transformer.h[0].attn.kv_cache is not None
# https://github.com/pytorch/pytorch/blob/ad3572a5d/torch/testing/_internal/common_cuda.py#L31-L34
SUPPORTS_FLASH_ATTENTION = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) and not _IS_WINDOWS
)
@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("config", deepcopy(config_module.configs), ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice(config):
if config["name"].startswith("Gemma-2-"):
pytest.skip("Gemma 2 doesn't support SDPA")
torch.set_default_dtype(torch.float16)
def assert_sdpa_backend(original_fn, q, k, v, mask):
# SDPAParams gained an additional argument in PyTorch 2.5
args = []
if hasattr(SDPAParams, "enable_gqa"):
args.append(False)
params = SDPAParams(q, k, v, mask, 0.0, True, *args)
if expected is SDPBackend.FLASH_ATTENTION:
assert flash_sdp_enabled(), "flash_sdp_enabled() is False"
if config.sliding_window_size is None:
assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False"
elif expected is SDPBackend.EFFICIENT_ATTENTION:
assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False"
assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False"
elif expected is SDPBackend.MATH:
assert math_sdp_enabled(), "math_sdp_enabled() is False"
else:
raise NotImplementedError
return original_fn(q, k, v, mask)
config["n_layer"] = 1
config = config_module.Config(**config)
try:
with torch.device("cuda"):
model = GPT(config)
x = torch.randint(0, 10, (2, 16), dtype=torch.int32)
except torch.cuda.OutOfMemoryError:
# best effort, if the GPU can load it
pytest.xfail()
for h in model.transformer.h:
h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention)
if SUPPORTS_FLASH_ATTENTION:
expected = SDPBackend.FLASH_ATTENTION
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
model(x)
expected = SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 else SDPBackend.MATH
with torch.backends.cuda.sdp_kernel(enable_flash=False):
model(x)
@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("config", deepcopy(config_module.configs), ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice_kv_cache(config):
torch.set_default_dtype(torch.float16)
def assert_sdpa_backend(original_fn, q, k, v, mask):
# SDPAParams gained an additional argument in PyTorch 2.5
args = []
if hasattr(SDPAParams, "enable_gqa"):
args.append(False)
params = SDPAParams(q, k, v, mask, 0.0, True, *args)
if expected is SDPBackend.FLASH_ATTENTION:
assert flash_sdp_enabled()
assert can_use_flash_attention(params, True)
elif expected is SDPBackend.EFFICIENT_ATTENTION:
assert mem_efficient_sdp_enabled()
assert can_use_efficient_attention(params, True)
elif expected is SDPBackend.MATH:
assert math_sdp_enabled()
else:
raise NotImplementedError
return original_fn(q, k, v, mask)
config["n_layer"] = 1
config = config_module.Config(**config)
try:
with torch.device("cuda"):
model = GPT(config)
model.max_seq_length = 1
model.set_kv_cache(2)
x = torch.randint(0, 10, (2, 1), dtype=torch.int32)
input_pos = torch.tensor([0], dtype=torch.long)
except torch.cuda.OutOfMemoryError:
# best effort, if the GPU can load it
pytest.xfail()
for h in model.transformer.h:
h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention)
if SUPPORTS_FLASH_ATTENTION:
# flash attention does not support an attention mask
expected = SDPBackend.MATH
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
model(x, input_pos)
expected = (
SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 and config.n_query_groups != 1 else SDPBackend.MATH
)
with torch.backends.cuda.sdp_kernel(enable_flash=False):
model(x, input_pos)
@_RunIf(min_cuda_gpus=2, standalone=True)
def test_rope_init_under_fsdp():
"""Check that the rope cache is properly initialized"""
fabric = Fabric(devices=2, strategy="fsdp", accelerator="cuda")
fabric.launch()
with fabric.init_module(empty_init=True):
model = GPT.from_name("pythia-14m", n_layer=1)
assert model.cos.device.type == "meta"
assert model.sin.device.type == "meta"
model = fabric.setup(model)
assert model.cos.device.type == "cuda"
assert model.sin.device.type == "cuda"
cos, sin = model.rope_cache(device=fabric.device)
torch.testing.assert_close(model.cos, cos)
torch.testing.assert_close(model.sin, sin)
@_RunIf(min_cuda_gpus=1)
def test_reset_parameters_device():
with torch.device("meta"):
model = GPT.from_name("pythia-14m", n_layer=1)
_materialize_meta_tensors(model, torch.device("cuda"))
model.reset_parameters()
assert model.cos.device.type == "cuda"
def test_batched_index_copy_modes():
# Mock the torch.backends.mps.is_available() function to simulate MPS availability
with mock.patch("torch.backends.mps.is_available", return_value=True):
# Mock the device type to simulate the "mps" device
with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device:
mock_device.return_value = torch.device("mps")
# Test case when idx.dim() == 1
t_original_1 = torch.randn(3, 5)
dim_1 = 0
idx_1 = torch.tensor([0, 2])
val_1 = torch.randn(2, 5)
t1_cpu = t_original_1.clone()
t1_mps = t_original_1.clone()
# Perform the index copy on CPU
batched_index_copy_(t1_cpu, dim_1, idx_1, val_1)
# Simulate the MPS index copy
idx_1_mps = idx_1
val_1_mps = val_1
batched_index_copy_(t1_mps, dim_1, idx_1_mps, val_1_mps)
assert torch.allclose(t1_cpu, t1_mps), "Mismatch with idx.dim() == 1 on mocked MPS"
# Test case when idx.dim() == 2
t_original_2 = torch.randn(2, 5, 4)
dim_2 = 1
idx_2 = torch.tensor([[0, 2], [1, 3]])
val_2 = torch.randn(2, 2, 4)
t2_cpu = t_original_2.clone()
t2_mps = t_original_2.clone()
# Perform the index copy on CPU
batched_index_copy_(t2_cpu, dim_2, idx_2, val_2)
# Simulate the MPS index copy
idx_2_mps = idx_2
val_2_mps = val_2
batched_index_copy_(t2_mps, dim_2, idx_2_mps, val_2_mps)
assert torch.allclose(t2_cpu, t2_mps), "Mismatch with idx.dim() == 2 on mocked MPS"
# Additional test with negative dimension
t_original_3 = torch.randn(2, 3, 4)
dim_3 = -2
idx_3 = torch.tensor([[0, 1], [1, 2]])
val_3 = torch.randn(2, 2, 4)
t3_cpu = t_original_3.clone()
t3_mps = t_original_3.clone()
# Perform the index copy on CPU
batched_index_copy_(t3_cpu, dim_3, idx_3, val_3)
# Simulate the MPS index copy
idx_3_mps = idx_3
val_3_mps = val_3
batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps)
assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS"
def test_load_legacy_state_dict():
"""Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers."""
config = Config(
n_embd=32,
n_head=4,
head_size=8,
n_query_groups=4,
bias=True,
)
attention_1 = CausalSelfAttention(config=config, block_idx=0)
# make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`
# and make them interleaved
state_dict = deepcopy(attention_1.state_dict())
state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config)
state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config)
attention_2 = CausalSelfAttention(config=config, block_idx=0)
attention_2.load_state_dict(state_dict)
@pytest.mark.parametrize("n_query_groups", (1, 2, 4, 8))
@torch.inference_mode()
def test_kv_cache_buffer_shape(n_query_groups):
batch_size = 3
max_seq_length = 23
config = Config(
block_size=25,
padded_vocab_size=5,
n_layer=2,
n_head=8,
n_embd=16,
n_query_groups=n_query_groups,
)
model = GPT(config)
model.max_seq_length = max_seq_length
model.set_kv_cache(batch_size)
required_shape = (batch_size, n_query_groups, max_seq_length, config.head_size)
for block in model.transformer.h:
kv_cache = block.attn.kv_cache
assert kv_cache is not None
assert kv_cache.k.shape == required_shape
assert kv_cache.v.shape == required_shape
@pytest.mark.parametrize(("rotary_percentage", "final_dim"), ((0.75, 3), (0.25, 2)))
@torch.inference_mode()
def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim):
batch_size = 3
config = Config(
block_size=25,
padded_vocab_size=5,
n_layer=2,
n_head=4,
n_embd=16,
rotary_percentage=rotary_percentage,
)
model = GPT(config)
required_shape = (config.block_size, final_dim)
assert model.cos.shape == required_shape
assert model.sin.shape == required_shape
def test_forward_with_without_input_pos_maxp1():
batch_size = 3
config = Config(
block_size=25,
padded_vocab_size=5,
n_layer=2,
n_head=8,
n_embd=16,
)
model = GPT(config)
model.set_kv_cache(batch_size)
idx = torch.randint(0, config.padded_vocab_size, (1, 10))
input_pos = torch.arange(1, 11)
input_pos_maxp1 = 11
logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
logits_no_maxp1 = model(idx, input_pos)
torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
================================================
FILE: tests/test_multihead_latent_attention.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM
from litgpt import Config
from litgpt.model import MultiheadLatentAttention
@torch.inference_mode()
def test_multihead_latent_attention_kv_cache():
"""Test KV cache functionality"""
config = Config(
block_size=32,
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
)
mla = MultiheadLatentAttention(config, block_idx=0)
# Build KV cache
kv_cache = mla.build_kv_cache(batch_size=2, max_seq_length=32, device=torch.device("cpu"), dtype=torch.float32)
# Check cache shapes
assert kv_cache.k.shape == (2, config.n_head, 32, config.qk_head_dim)
assert kv_cache.v.shape == (2, config.n_head, 32, config.v_head_dim)
@torch.inference_mode()
def test_multihead_latent_attention_with_mask():
"""Test attention with causal mask"""
config = Config(
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
)
mla = MultiheadLatentAttention(config, block_idx=0)
batch_size, seq_len = 1, 8
x = torch.randn(batch_size, seq_len, config.n_embd)
cos = torch.randn(1, seq_len, config.qk_rope_head_dim)
sin = torch.randn(1, seq_len, config.qk_rope_head_dim)
# Create causal mask
mask = torch.ones(seq_len, seq_len, dtype=x.dtype).triu(diagonal=1)
mask.masked_fill_(mask.bool(), float("-inf"))
mask = mask.view(1, 1, seq_len, seq_len)
# Forward pass with mask
output = mla(x, cos, sin, mask=mask)
assert output.shape == (batch_size, seq_len, config.n_embd)
@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device):
"""Test MLA litgpt vs hf"""
config_litgpt = Config(
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
norm_eps=1e-6,
bias=False,
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
)
config_hf = DeepseekV3Config(
padded_vocab_size=10000,
num_hidden_layers=1,
vocab_size=10000,
hidden_size=64,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=16,
qk_rope_head_dim=8,
qk_nope_head_dim=8,
v_head_dim=16,
rope_interleave=False,
)
mla_litgpt = MultiheadLatentAttention(config_litgpt, block_idx=0).to(device)
model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
mla_hf = model_hf.model.layers[0].self_attn
mla_litgpt.eval()
mla_hf.eval()
sync_weights(mla_litgpt, mla_hf)
hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)
# Prepare RoPE sin/cos tables
rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"]
cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype)
sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype)
causal_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
# Run forward passes
output_litgpt = mla_litgpt(hidden_states, cos, sin)
output_hf = mla_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)[0]
assert torch.allclose(output_litgpt, output_hf, atol=1e-5)
def sync_weights(litgpt_model, hf_model):
"""Copies weights from lit-gpt model to HF model."""
print("Synchronizing weights...")
with torch.no_grad():
hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)
hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)
hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)
hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)
hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)
hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)
hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)
print("Synchronization complete.")
================================================
FILE: tests/test_pretrain.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import ANY, Mock
import pytest
import torch
from lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy
from torch.utils.data import DataLoader
from litgpt import pretrain
from litgpt.args import EvalArgs, TrainArgs
from litgpt.config import Config
from litgpt.pretrain import initialize_weights
from litgpt.utils import _RunIf
@_RunIf(min_cuda_gpus=1, standalone=True)
@mock.patch("litgpt.pretrain.save_hyperparameters")
def test_optimizer_args(_, tmp_path):
model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8)
dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
for i in ("AdamW", "SGD", "RMSprop"):
pretrain.setup(
"pythia-14m",
devices=1,
optimizer="RMSprop",
model_config=model_config,
out_dir=tmp_path,
train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
)
@_RunIf(min_cuda_gpus=2, standalone=True)
# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise
# the CLI would capture pytest args, but unfortunately patching would mess with subprocess
# launching, so we need to mock `save_hyperparameters()`
@mock.patch("litgpt.pretrain.save_hyperparameters")
# todo: it expects exactly 2 GPUs and has strange failing for validated 4 # GPUs, so we temporarily mark it as xfail
@pytest.mark.xfail(condition=torch.cuda.device_count() != 2, reason="This test is flaky, expects exactly 2 GPUs")
def test_pretrain(_, tmp_path):
model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
pretrain.setup(
"pythia-14m",
devices=2,
model_config=model_config,
out_dir=out_dir,
train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
eval=EvalArgs(interval=1, max_iters=1, final_validation=False),
)
if torch.distributed.get_rank() == 0:
# tmp_path is not the same across all ranks, run assert only on rank 0
out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"step-00000001", "step-00000002", "step-00000003", "step-00000004", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
# the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "model_config.yaml"}
assert (out_dir / "logs" / "tensorboard" / "version_0").is_dir()
# logs only appear on rank 0
logs = stdout.getvalue()
assert logs.count("(step)") == 4
assert logs.count("val loss") == 4
assert "Total parameters: 1,888" in logs
torch.distributed.barrier()
@_RunIf(min_cuda_gpus=2, standalone=True)
@mock.patch("litgpt.pretrain.L.Fabric.load_raw")
# See comment in `test_pretrain` why we need to mock `save_hyperparameters()`
@mock.patch("litgpt.pretrain.save_hyperparameters")
def test_initial_checkpoint_dir(_, load_mock, tmp_path):
model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
pretrain.fit = Mock()
pretrain.setup(
"pythia-14m",
initial_checkpoint_dir=tmp_path,
devices=torch.cuda.device_count(),
model_config=model_config,
out_dir=tmp_path,
)
load_mock.assert_called_once_with(tmp_path / "lit_model.pth", ANY)
@pytest.mark.parametrize(("strategy", "expected"), [(SingleDeviceStrategy, True), (FSDPStrategy, False)])
def test_initialize_weights(strategy, expected):
fabric_mock = Mock()
fabric_mock.strategy = Mock(spec=strategy)
class Child(torch.nn.Module):
pass
class Parent(torch.nn.Module):
def __init__(self):
super().__init__()
self.child = Child()
model = Parent()
model.reset_parameters = Mock()
model.child.reset_parameters = Mock()
initialize_weights(fabric_mock, model, n_layer=2, n_embd=8)
assert model.reset_parameters.call_count == int(expected)
assert model.child.reset_parameters.call_count == int(expected)
================================================
FILE: tests/test_prompts.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from typing import Optional
import pytest
import yaml
import litgpt.config
from litgpt import Config
from litgpt.prompts import (
Alpaca,
Default,
Llama3,
Phi3,
PromptStyle,
has_prompt_style,
load_prompt_style,
prompt_styles,
save_prompt_style,
)
def test_default_prompt_style(mock_tokenizer):
prompt_style = Default()
prompt = "This is a test prompt."
assert prompt_style.apply(prompt) == prompt
assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)
@pytest.mark.parametrize("sys_prompt", [None, "You are a helpful coding assistant."])
def test_sys_prompt(mock_tokenizer, sys_prompt: Optional[str]):
prompt_style = Phi3()
prompt = "This is a test prompt."
default_sys_prompt = "You are a helpful assistant."
response = f"<|system|>\n{sys_prompt or default_sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
assert prompt_style.apply(prompt, sys_prompt=sys_prompt) == response
assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)
@pytest.mark.parametrize("sys_prompt", [None, "You are a helpful coding assistant."])
def test_sys_prompt_with_kwargs(mock_tokenizer, sys_prompt: Optional[str]):
prompt_style = Phi3()
prompt = "This is a test prompt."
default_sys_prompt = "You are a helpful assistant."
response = f"<|system|>\n{sys_prompt or default_sys_prompt}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
assert prompt_style.apply(prompt, sys_prompt=sys_prompt, test=1) == response
assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)
def test_prompt_style_from_name():
for style_name in prompt_styles:
assert isinstance(PromptStyle.from_name(style_name), prompt_styles[style_name])
def test_prompt_style_from_config():
model_names = [
"stablelm-tuned-alpha-3b",
"stablelm-tuned-alpha-7b",
"stablelm-zephyr-3b",
"stablecode-instruct-alpha-3b",
"falcon-7b-instruct",
"falcon-40b-instruct",
"Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf",
"Llama-3-8B-Instruct",
"Llama-3-70B-Instruct",
"Llama-3.1-405B-Instruct",
"Gemma-2b-it",
"Gemma-7b-it",
"FreeWilly2",
"CodeLlama-7b-Instruct-hf",
"CodeLlama-13b-Instruct-hf",
"CodeLlama-34b-Instruct-hf",
"CodeLlama-70b-Instruct-hf",
"phi-1_5",
"phi-2",
"Phi-3-mini-4k-instruct",
"Mistral-7B-Instruct-v0.1",
"Mistral-7B-Instruct-v0.2",
"tiny-llama-1.1b-chat",
"Llama-2-7b-chat-hf-function-calling-v2",
]
for c in litgpt.config.platypus:
model_names.append(c["name"])
for model_name in model_names:
# by asserting the returned style is not the Default, we show that at least one of the regex patterns matched
assert not isinstance(PromptStyle.from_config(Config.from_name(model_name)), Default)
def test_apply_prompts():
prompt = "Is a coconut a nut or a fruit?"
inp = "Optional input"
for style in prompt_styles.values():
output = style().apply(prompt, input=inp)
assert prompt in output
if isinstance(style, Alpaca):
assert inp in output
class CustomPromptStyle(PromptStyle):
def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:
return prompt
def test_save_load_prompt_style(tmp_path):
# Save and load a built-in style
checkpoint_dir = tmp_path / "checkpoint"
checkpoint_dir.mkdir()
assert not has_prompt_style(checkpoint_dir)
save_prompt_style("alpaca", checkpoint_dir)
assert has_prompt_style(checkpoint_dir)
with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file:
contents = yaml.safe_load(file)
assert contents == {"class_path": "litgpt.prompts.Alpaca"}
loaded = load_prompt_style(checkpoint_dir)
assert isinstance(loaded, Alpaca)
# Save a custom style
checkpoint_dir = tmp_path / "custom"
checkpoint_dir.mkdir()
save_prompt_style(CustomPromptStyle(), checkpoint_dir)
with open(checkpoint_dir / "prompt_style.yaml", encoding="utf-8") as file:
contents = yaml.safe_load(file)
assert contents == {"class_path": "test_prompts.CustomPromptStyle"}
loaded = load_prompt_style(checkpoint_dir)
assert isinstance(loaded, CustomPromptStyle)
def test_multiturn_prompt():
prompt = "What is the capital of France?"
msgs = [{"role": "user", "content": prompt}]
style = Llama3()
simple_output = style.apply(prompt)
multiturn_output = style.apply(msgs)
assert simple_output == multiturn_output
# override system prompt
msgs = [{"role": "system", "content": "You are not a helpful assistant."}, {"role": "user", "content": prompt}]
with_system_multiturn_output = style.apply(msgs)
assert "You are not a helpful assistant." in with_system_multiturn_output
# use default system prompt
msgs = [
{"role": "user", "content": prompt},
]
wo_system_multiturn_output = style.apply(msgs)
assert "You are a helpful assistant." in wo_system_multiturn_output
# Longer turn
msgs = [
{"role": "system", "content": "You are a helpful AI assistant for travel tips and recommendations"},
{"role": "user", "content": "What is France's capital?"},
{"role": "assistant", "content": "Bonjour! The capital of France is Paris!"},
{"role": "user", "content": "What can I do there?"},
]
multiturn_output = style.apply(msgs)
assert (
multiturn_output
== """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>
What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>
What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
)
# Longer list without "system"
msgs = [
{"role": "user", "content": "What is France's capital?"},
{"role": "assistant", "content": "Bonjour! The capital of France is Paris!"},
{"role": "user", "content": "What can I do there?"},
]
multiturn_output = style.apply(msgs)
assert (
multiturn_output
== """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Bonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>
What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
)
# {random} string format shouldn't lead to key error
content = "this is {random} {system} {user}"
msgs = [{"role": "user", "content": content}]
output = style.apply(msgs)
simple_output = style.apply(content)
assert output == simple_output
================================================
FILE: tests/test_readme.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import platform
import subprocess
import sys
import threading
import time
from pathlib import Path
from unittest import mock
import pytest
import requests
from urllib3.exceptions import MaxRetryError
from litgpt.utils import _RunIf, kill_process_tree
REPO_ID = Path("EleutherAI/pythia-14m")
CUSTOM_TEXTS_DIR = Path("custom_texts")
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
return result.stdout
except subprocess.CalledProcessError as e:
error_message = (
f"Command '{' '.join(command)}' failed with exit status {e.returncode}\n"
f"Output:\n{e.stdout}\n"
f"Error:\n{e.stderr}"
)
# You can either print the message, log it, or raise an exception with it
print(error_message)
raise RuntimeError(error_message) from None
def _wait_and_check_response(waiting: int = 30):
response_status_code, err = -1, None
for _ in range(waiting):
try:
response = requests.get("http://127.0.0.1:8000", timeout=1)
response_status_code = response.status_code
except (MaxRetryError, requests.exceptions.ConnectionError) as ex:
response_status_code = -1
err = str(ex)
if response_status_code == 200:
break
time.sleep(1)
assert response_status_code == 200, "Server did not respond as expected. Error: {err}"
@pytest.mark.dependency()
@pytest.mark.flaky(reruns=5, reruns_delay=2)
def test_download_model():
repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI
command = ["litgpt", "download", str(repo_id)]
output = run_command(command)
s = Path("checkpoints") / repo_id
assert f"Saving converted checkpoint to {str(s)}" in output
assert ("checkpoints" / REPO_ID).exists()
# Also test valid but unsupported repo IDs
command = ["litgpt", "download", "CohereForAI/aya-23-8B"]
output = run_command(command)
assert "Unsupported `repo_id`" in output
@pytest.mark.dependency()
@pytest.mark.flaky(reruns=5, reruns_delay=2)
def test_download_books():
CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True)
books = [
("https://www.gutenberg.org/cache/epub/24440/pg24440.txt", "book1.txt"),
("https://www.gutenberg.org/cache/epub/26393/pg26393.txt", "book2.txt"),
]
for url, filename in books:
subprocess.run(["curl", url, "--output", str(CUSTOM_TEXTS_DIR / filename)], check=True)
# Verify each book is downloaded
assert (CUSTOM_TEXTS_DIR / filename).exists(), f"{filename} not downloaded"
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model"])
def test_chat_with_model():
command = ["litgpt", "generate", "checkpoints" / REPO_ID]
prompt = "What do Llamas eat?"
result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)
assert "What food do llamas eat?" in result.stdout
@_RunIf(min_cuda_gpus=1)
@pytest.mark.dependency(depends=["test_download_model"])
def test_chat_with_quantized_model():
command = ["litgpt", "generate", "checkpoints" / REPO_ID, "--quantize", "bnb.nf4", "--precision", "bf16-true"]
prompt = "What do Llamas eat?"
result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)
assert "What food do llamas eat?" in result.stdout, result.stdout
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model"])
@pytest.mark.timeout(300)
def test_finetune_model(tmp_path):
OUT_DIR = tmp_path / "out" / "lora"
DATASET_PATH = tmp_path / "custom_finetuning_dataset.json"
CHECKPOINT_DIR = "checkpoints" / REPO_ID
download_command = [
"curl",
"-L",
"https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json",
"-o",
str(DATASET_PATH),
]
subprocess.run(download_command, check=True)
assert DATASET_PATH.exists(), "Dataset file not downloaded"
finetune_command = [
"litgpt",
"finetune_lora",
str(CHECKPOINT_DIR),
"--lora_r",
"1",
"--data",
"JSON",
"--data.json_path",
str(DATASET_PATH),
"--data.val_split_fraction",
"0.00001", # Keep small because new final validation is expensive
"--train.max_steps",
"1",
"--out_dir",
str(OUT_DIR),
]
run_command(finetune_command)
generated_out_dir = OUT_DIR / "final"
assert generated_out_dir.exists(), f"Finetuning output directory ({generated_out_dir}) was not created"
model_file = OUT_DIR / "final" / "lit_model.pth"
assert model_file.exists(), f"Model file ({model_file}) was not created"
@pytest.mark.skipif(
sys.platform.startswith("win") or sys.platform == "darwin", reason="`torch.compile` is not supported on this OS."
)
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_pretrain_model(tmp_path):
OUT_DIR = tmp_path / "out" / "custom_pretrained"
pretrain_command = [
"litgpt",
"pretrain",
"pythia-14m",
"--tokenizer_dir",
str("checkpoints" / REPO_ID),
"--data",
"TextFiles",
"--data.train_data_path",
str(CUSTOM_TEXTS_DIR),
"--train.max_tokens",
"100", # to accelerate things for CI
"--eval.max_iters",
"1", # to accelerate things for CI
"--out_dir",
str(OUT_DIR),
]
output = run_command(pretrain_command)
assert "Warning: Preprocessed training data found" not in output
out_dir_path = OUT_DIR / "final"
assert out_dir_path.exists(), f"Pretraining output directory ({out_dir_path}) was not created"
out_model_path = OUT_DIR / "final" / "lit_model.pth"
assert out_model_path.exists(), f"Model file ({out_model_path}) was not created"
# Test that warning is displayed when running it a second time
output = run_command(pretrain_command)
assert "Warning: Preprocessed training data found" in output
@pytest.mark.skipif(
sys.platform.startswith("win") or sys.platform == "darwin", reason="`torch.compile` is not supported on this OS."
)
@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"})
@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_continue_pretrain_model(tmp_path):
OUT_DIR = tmp_path / "out" / "custom_continue_pretrained"
pretrain_command = [
"litgpt",
"pretrain",
"pythia-14m",
"--initial_checkpoint",
str("checkpoints" / REPO_ID),
"--tokenizer_dir",
str("checkpoints" / REPO_ID),
"--data",
"TextFiles",
"--data.train_data_path",
str(CUSTOM_TEXTS_DIR),
"--train.max_tokens",
"100", # to accelerate things for CI
"--eval.max_iters",
"1", # to accelerate things for CI
"--out_dir",
str(OUT_DIR),
]
run_command(pretrain_command)
generated_out_dir = OUT_DIR / "final"
assert generated_out_dir.exists(), f"Continued pretraining directory ({generated_out_dir}) was not created"
model_file = OUT_DIR / "final" / "lit_model.pth"
assert model_file.exists(), f"Model file ({model_file}) was not created"
@pytest.mark.dependency(depends=["test_download_model"])
# todo: try to resolve this issue
@pytest.mark.xfail(condition=platform.system() == "Darwin", reason="it passes locally but having some issues on CI")
def test_serve():
CHECKPOINT_DIR = str("checkpoints" / REPO_ID)
run_command = ["litgpt", "serve", str(CHECKPOINT_DIR)]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate(timeout=60)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
if process:
kill_process_tree(process.pid)
server_thread.join()
================================================
FILE: tests/test_rope.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
from litgpt.model import apply_rope, build_rope_cache
@torch.inference_mode()
def test_rope_gptneox():
bs, seq_len, n_head, n_embed = 1, 6, 2, 8
head_size = n_embed // n_head # 4
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
position_ids = torch.arange(seq_len).unsqueeze(0)
config = GPTNeoXConfig(num_attention_heads=n_head, hidden_size=head_size * n_embed)
theirs_rot_emb = GPTNeoXRotaryEmbedding(config)
theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)
ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
ours_cos_cached = ours_cos_cached.unsqueeze(0)
ours_sin_cached = ours_sin_cached.unsqueeze(0)
torch.testing.assert_close(ours_cos_cached, theirs_cos)
torch.testing.assert_close(ours_sin_cached, theirs_sin)
ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
torch.testing.assert_close(ours_x_rope, theirs_x_rope)
@torch.inference_mode()
def test_rope_llama_2():
head_dim = 64
rope_theta = 10_000
##################################
# Compare cos and sin
##################################
# transformer rope
their_rope_config = {
"rope_type": "default",
}
config = LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config)
rot_emb = LlamaRotaryEmbedding(config=config)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
ours_cos = ours_cos.unsqueeze(0)
ours_sin = ours_sin.unsqueeze(0)
torch.testing.assert_close(theirs_cos, ours_cos)
torch.testing.assert_close(theirs_sin, ours_sin)
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
# See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3():
head_dim = 64
rope_theta = 50_000
##################################
# Compare cos and sin
##################################
# transformer rope
their_rope_config = {
"rope_type": "default",
}
config = LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config)
rot_emb = LlamaRotaryEmbedding(config=config)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
ours_cos = ours_cos.unsqueeze(0)
ours_sin = ours_sin.unsqueeze(0)
torch.testing.assert_close(theirs_cos, ours_cos)
torch.testing.assert_close(theirs_sin, ours_sin)
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
# See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_1():
head_dim = 32
rope_theta = 50_000
their_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
our_rope_config = {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_seq_len": 8192}
config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)
##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(config=config)
batch_size, seq_len = 1, 131_072
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
ours_cos = ours_cos.unsqueeze(0)
ours_sin = ours_sin.unsqueeze(0)
torch.testing.assert_close(theirs_cos, ours_cos)
torch.testing.assert_close(theirs_sin, ours_sin)
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
# See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_2():
head_dim = 32
rope_theta = 50_000
their_rope_config = {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
our_rope_config = {"factor": 32.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_seq_len": 8192}
config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)
##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(config=config)
batch_size, seq_len = 1, 131_072
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
ours_cos = ours_cos.unsqueeze(0)
ours_sin = ours_sin.unsqueeze(0)
torch.testing.assert_close(theirs_cos, ours_cos)
torch.testing.assert_close(theirs_sin, ours_sin)
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
# See https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_gemma_3():
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb
head_dim = 32
rope_theta = 50_000
their_rope_config = {
"factor": 8.0,
"rope_type": "linear",
}
our_rope_config = {"factor": 8.0}
##################################
# Compare cos and sin
##################################
# transformer rope
config = Gemma3TextConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)
rot_emb = Gemma3RotaryEmbedding(config=config)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
ours_cos = ours_cos.unsqueeze(0)
ours_sin = ours_sin.unsqueeze(0)
torch.testing.assert_close(theirs_cos, ours_cos)
torch.testing.assert_close(theirs_sin, ours_sin)
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4
# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
@torch.inference_mode()
def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd():
bs, seq_len, n_head, n_embed = 1, 6, 2, 8
head_size = n_embed // n_head # 4
rotary_percentage = 0.75
rope_n_elem = int(head_size * rotary_percentage) # 3
ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem)
required_shape = (seq_len, rope_n_elem)
assert ours_cos.shape == required_shape
assert ours_sin.shape == required_shape
# Special case: If `rope_n_elem == 1`, the shape is extended. This is to
# accommodate a current bug in Hugging Face, ensuring that other unit tests
# pass.
# https://github.com/huggingface/transformers/issues/35233
rotary_percentage = 0.25
rope_n_elem = int(head_size * rotary_percentage) # 1
ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem)
required_shape = (seq_len, rope_n_elem + 1)
assert ours_cos.shape == required_shape
assert ours_sin.shape == required_shape
================================================
FILE: tests/test_serve.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
import platform
import shutil
import subprocess
import threading
import time
from dataclasses import asdict
import pytest
import requests
import torch
import yaml
from lightning.fabric import seed_everything
from urllib3.exceptions import MaxRetryError
from litgpt import GPT, Config
from litgpt.scripts.download import download_from_hub
from litgpt.utils import _RunIf, kill_process_tree
def _wait_and_check_response(waiting: int = 30):
response_status_code, err = -1, None
for _ in range(waiting):
try:
response = requests.get("http://127.0.0.1:8000", timeout=10)
response_status_code = response.status_code
except (MaxRetryError, requests.exceptions.ConnectionError) as ex:
response_status_code = -1
err = str(ex)
if response_status_code == 200:
break
time.sleep(1)
assert response_status_code == 200, f"Server did not respond as expected. Error: {err}"
# todo: try to resolve this issue
@pytest.mark.flaky(reruns=2, reruns_delay=30)
@pytest.mark.xfail(condition=platform.system() == "Darwin", reason="it passes locally but having some issues on CI")
def test_simple(tmp_path):
seed_everything(123)
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
run_command = ["litgpt", "serve", tmp_path]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response(waiting=60)
if process:
kill_process_tree(process.pid)
server_thread.join()
@_RunIf(min_cuda_gpus=1)
def test_quantize(tmp_path):
seed_everything(123)
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
run_command = ["litgpt", "serve", tmp_path, "--quantize", "bnb.nf4"]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
if process:
kill_process_tree(process.pid)
server_thread.join()
@_RunIf(min_cuda_gpus=2)
def test_multi_gpu_serve(tmp_path):
seed_everything(123)
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
run_command = ["litgpt", "serve", tmp_path, "--devices", "2"]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
if process:
kill_process_tree(process.pid)
server_thread.join()
@_RunIf(min_cuda_gpus=1)
def test_serve_with_openai_spec_missing_chat_template(tmp_path):
seed_everything(123)
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
if process:
kill_process_tree(process.pid)
server_thread.join()
@_RunIf(min_cuda_gpus=1)
def test_serve_with_openai_spec(tmp_path):
seed_everything(123)
ours_config = Config.from_name("SmolLM2-135M-Instruct")
download_from_hub(repo_id="HuggingFaceTB/SmolLM2-135M-Instruct", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
try:
# Test server health
response = requests.get("http://127.0.0.1:8000/health")
assert response.status_code == 200, f"Server health check failed with status code {response.status_code}"
assert response.text == "ok", "Server did not respond as expected."
# Test non-streaming chat completion
response = requests.post(
"http://127.0.0.1:8000/v1/chat/completions",
json={
"model": "SmolLM2-135M-Instruct",
"messages": [{"role": "user", "content": "Hello!"}],
},
)
assert response.status_code == 200, (
f"Non-streaming chat completion failed with status code {response.status_code}"
)
response_json = response.json()
assert "choices" in response_json, "Response JSON does not contain 'choices'."
assert "message" in response_json["choices"][0], "Response JSON does not contain 'message' in 'choices'."
assert "content" in response_json["choices"][0]["message"], (
"Response JSON does not contain 'content' in 'message'."
)
assert response_json["choices"][0]["message"]["content"], "Content is empty in the response."
# Test streaming chat completion
stream_response = requests.post(
"http://127.0.0.1:8000/v1/chat/completions",
json={
"model": "SmolLM2-135M-Instruct",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": True,
},
)
assert stream_response.status_code == 200, (
f"Streaming chat completion failed with status code {stream_response.status_code}"
)
for line in stream_response.iter_lines():
decoded = line.decode("utf-8").replace("data: ", "").replace("[DONE]", "").strip()
if decoded:
data = json.loads(decoded)
assert "choices" in data, "Response JSON does not contain 'choices'."
assert "delta" in data["choices"][0], "Response JSON does not contain 'delta' in 'choices'."
assert "content" in data["choices"][0]["delta"], "Response JSON does not contain 'content' in 'delta'."
finally:
if process:
kill_process_tree(process.pid)
server_thread.join()
@pytest.mark.parametrize(
"generate_strategy",
[
pytest.param("sequential", marks=_RunIf(min_cuda_gpus=1)),
pytest.param("tensor_parallel", marks=_RunIf(min_cuda_gpus=2)),
],
)
def test_serve_with_generate_strategy(tmp_path, generate_strategy):
seed_everything(123)
ours_config = Config.from_name("pythia-14m")
download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
ours_model = GPT(ours_config)
checkpoint_path = tmp_path / "lit_model.pth"
torch.save(ours_model.state_dict(), checkpoint_path)
config_path = tmp_path / "model_config.yaml"
with open(config_path, "w", encoding="utf-8") as fp:
yaml.dump(asdict(ours_config), fp)
# Test with generate strategy
run_command = ["litgpt", "serve", tmp_path, "--generate_strategy", generate_strategy]
process = None
def run_server():
nonlocal process
try:
process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)
except subprocess.TimeoutExpired:
print("Server start-up timeout expired")
server_thread = threading.Thread(target=run_server)
server_thread.start()
_wait_and_check_response()
if process:
kill_process_tree(process.pid)
server_thread.join()
================================================
FILE: tests/test_tokenizer.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import shutil
import warnings
from types import SimpleNamespace
from unittest import mock
import pytest
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.models import BPE
from transformers import AutoTokenizer
from transformers.utils import cached_file
import litgpt.config as config_module
from litgpt import PromptStyle, Tokenizer
# @pytest.mark.flaky(reruns=3, rerun_except=["AssertionError", "assert", "TypeError"])
@pytest.mark.flaky(reruns=3, reruns_delay=120)
@pytest.mark.parametrize("config", config_module.configs, ids=[c["hf_config"]["name"] for c in config_module.configs])
def test_tokenizer_against_hf(config, tmp_path):
config = config_module.Config(**config)
repo_id = f"{config.hf_config['org']}/{config.hf_config['name']}"
theirs = AutoTokenizer.from_pretrained(repo_id, token=os.getenv("HF_TOKEN"))
# create a checkpoint directory that points to the HF files
hf_files = {}
for filename in ("tokenizer.json", "generation_config.json", "tokenizer.model", "tokenizer_config.json"):
try: # download the HF tokenizer config
hf_file = cached_file(path_or_repo_id=repo_id, filename=filename)
hf_files[filename] = str(hf_file)
except Exception as ex:
warnings.warn(str(ex), RuntimeWarning)
if "tokenizer.json" not in hf_files and "tokenizer.model" not in hf_files:
raise ConnectionError("Unable to download any tokenizer files from HF")
# Create a clean, model-specific subdirectory for this test run.
# This avoids errors if previous runs or retries left files behind, ensuring the directory is always ready for fresh downloads and comparisons.
model_dir = tmp_path / config.hf_config["name"]
if model_dir.exists():
shutil.rmtree(model_dir)
os.makedirs(model_dir, exist_ok=True)
for filename, hf_file in hf_files.items():
shutil.copy(hf_file, model_dir / filename)
ours = Tokenizer(model_dir)
assert ours.vocab_size == theirs.vocab_size
if config.name == "Mixtral-8x22B-v0.1":
pytest.xfail(reason="Mixtral certainly lists 32000 vocab in its config")
else:
assert ours.vocab_size == config.vocab_size
if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ", "Qwen3")):
# even though their config defines it, it's set as None in HF
assert isinstance(ours.bos_id, int)
assert theirs.bos_token_id is None
elif config.name.startswith("Falcon3"):
if isinstance(ours.bos_id, int):
assert theirs.bos_token_id is None
else:
assert ours.bos_id == theirs.bos_token_id is None
else:
assert ours.bos_id == theirs.bos_token_id
if config.name.startswith("stablecode"):
# even though their config defines it, it's set as None in HF
assert ours.eos_id == 0
assert ours.eos_id == theirs.eos_token_id or theirs.eos_token_id is None
else:
assert ours.eos_id == theirs.eos_token_id
prompt = "Hello, readers of this test!"
prompt = PromptStyle.from_config(config).apply(prompt)
actual = ours.encode(prompt)
expected = theirs.encode(prompt)
assert actual.tolist() == expected
assert ours.decode(actual) == theirs.decode(expected, skip_special_tokens=True)
if not config.name.startswith(("Mistral", "Mixtral")):
decoded_output = "".join([ours.decode(x) for x in actual])
if ours.apply_decoding_fix and decoded_output[0] == " ":
decoded_output = decoded_output[1:] # the "hack" adds an empty space to the beginning
assert decoded_output == ours.decode(actual), type(theirs)
def test_tokenizer_input_validation():
with pytest.raises(NotADirectoryError, match="The checkpoint directory does not exist"):
Tokenizer("cocofruit")
@pytest.mark.parametrize("use_bos_by_default", (True, False))
@pytest.mark.parametrize("encode_use_bos", (None, True, False))
@pytest.mark.parametrize("encode_use_eos", (True, False))
@pytest.mark.parametrize("processor_returns_bos", (True, False))
@pytest.mark.parametrize("fake_return_ids", ([], [34, 8, 17, 2]))
def test_tokenizer_bos_eos(
tmp_path, use_bos_by_default, encode_use_bos, encode_use_eos, processor_returns_bos, fake_return_ids
):
# let `Tokenizers` create a proper (albeit empty) vocab in json format
HFTokenizer(BPE()).save(str(tmp_path / "tokenizer.json"))
tokenizer = Tokenizer(tmp_path)
tokenizer.bos_id = 0
tokenizer.eos_id = 1
tokenizer.use_bos = use_bos_by_default
if processor_returns_bos:
fake_return_ids = [tokenizer.bos_id] + fake_return_ids
fake_return_ids = SimpleNamespace(**dict(ids=fake_return_ids))
with mock.patch.object(tokenizer.processor, "encode", return_value=fake_return_ids):
tokens = tokenizer.encode("Hello world", bos=encode_use_bos, eos=encode_use_eos).tolist()
if encode_use_bos or (encode_use_bos is None and use_bos_by_default):
assert tokens[0] == tokenizer.bos_id
else:
assert not tokens or tokens[0] != tokenizer.bos_id
if encode_use_eos:
assert tokens[-1] == tokenizer.eos_id
else:
assert not tokens or tokens[-1] != tokenizer.eos_id
# both `bos` and `eos` should either not be found or occur only once at the begging (bos)
# or at the end (eos) of the tokens sequence
assert max([id for id, token in enumerate(tokens) if token == tokenizer.bos_id], default=0) == 0
assert max([id for id, token in enumerate(tokens[::-1]) if token == tokenizer.eos_id], default=0) == 0
================================================
FILE: tests/test_trainer_support.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from pathlib import Path
import lightning as L
import pytest
import torch
from litgpt.api import LLM
from litgpt.data import Alpaca2k
from litgpt.utils import _RunIf
REPO_ID = Path("EleutherAI/pythia-14m")
class LitLLM(L.LightningModule):
def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):
super().__init__()
self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)
self.trainer_ckpt_path = trainer_ckpt_path
def setup(self, stage):
self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)
def training_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("train_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("validation_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
warmup_steps = 10
optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
return [optimizer], [scheduler]
@pytest.mark.dependency()
def test_download_model():
LLM.load(model="EleutherAI/pythia-14m", distribute=None)
@pytest.mark.dependency(depends=["test_download_model"])
@_RunIf(min_cuda_gpus=1)
def test_usecase1_pretraining_from_random_weights(tmp_path):
llm = LLM.load("EleutherAI/pythia-14m", tokenizer_dir="EleutherAI/pythia-14m", init="random")
llm.save("pythia-14m-random-weights")
del llm
lit_model = LitLLM(checkpoint_dir="pythia-14m-random-weights", tokenizer_dir="EleutherAI/pythia-14m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)
trainer = L.Trainer(
max_epochs=1,
overfit_batches=2,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
text = lit_model.llm.generate("hello world")
assert isinstance(text, str)
@pytest.mark.dependency(depends=["test_download_model"])
@_RunIf(min_cuda_gpus=1)
def test_usecase2_continued_pretraining_from_checkpoint(tmp_path):
lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-14m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)
trainer = L.Trainer(
accelerator="cuda",
max_epochs=1,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
text = lit_model.llm.generate("hello world")
assert isinstance(text, str)
@pytest.mark.dependency(depends=["test_download_model", "test_usecase2_continued_pretraining_from_checkpoint"])
@_RunIf(min_cuda_gpus=1)
def test_usecase3_resume_from_trainer_checkpoint(tmp_path):
def find_latest_checkpoint(directory):
latest_checkpoint = None
latest_time = 0
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".ckpt"):
file_path = os.path.join(root, file)
file_time = os.path.getmtime(file_path)
if file_time > latest_time:
latest_time = file_time
latest_checkpoint = file_path
return latest_checkpoint
lit_model = LitLLM(
checkpoint_dir="EleutherAI/pythia-14m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs")
)
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)
trainer = L.Trainer(
accelerator="cuda",
max_epochs=1,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
text = lit_model.llm.generate("hello world")
assert isinstance(text, str)
@pytest.mark.dependency(depends=["test_download_model", "test_usecase2_continued_pretraining_from_checkpoint"])
@_RunIf(min_cuda_gpus=1)
def test_usecase4_manually_save_and_resume(tmp_path):
lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-14m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)
trainer = L.Trainer(
accelerator="cuda",
max_epochs=1,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
text = lit_model.llm.generate("hello world")
assert isinstance(text, str)
lit_model.llm.save("finetuned_checkpoint")
del lit_model
lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint")
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
text = lit_model.llm.generate("hello world")
assert isinstance(text, str)
================================================
FILE: tests/test_types.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from typing import get_args
from litgpt.constants import _SUPPORTED_LOGGERS
from litgpt.types import LoggerChoice
def test_logger_types_match_constants():
"""Ensure LoggerChoice and _SUPPORTED_LOGGERS stay synchronized."""
logger_choice_args = get_args(LoggerChoice)
assert logger_choice_args == _SUPPORTED_LOGGERS, (
f"LoggerChoice type args {logger_choice_args} != "
f"_SUPPORTED_LOGGERS {_SUPPORTED_LOGGERS}. "
f"These must stay synchronized. Update both litgpt/types.py and "
f"litgpt/constants.py when adding new loggers."
)
================================================
FILE: tests/test_utils.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from contextlib import redirect_stderr
from dataclasses import asdict
from io import StringIO
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest import mock
import pytest
import torch
import torch.nn.functional as F
import yaml
from lightning import Fabric
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.pytorch.loggers import LitLogger, MLFlowLogger, WandbLogger
from litgpt import GPT
from litgpt.args import TrainArgs
from litgpt.constants import (
_LITLOGGER_AVAILABLE,
_MLFLOW_AVAILABLE,
_MLFLOW_SKINNY_AVAILABLE,
_TENSORBOARD_AVAILABLE,
_WANDB_AVAILABLE,
)
from litgpt.parser_config import save_hyperparameters
from litgpt.utils import (
CLI,
CycleIterator,
_RunIf,
capture_hparams,
check_file_size_on_cpu_and_warn,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
find_resume_path,
fix_and_load_json,
incremental_save,
init_out_dir,
instantiate_bnb_optimizer,
instantiate_torch_optimizer,
num_parameters,
parse_devices,
select_sft_generate_example,
)
# match fails on windows. why did they have to use backslashes?
@_RunIf(skip_windows=True)
def test_check_valid_checkpoint_dir(tmp_path):
os.chdir(tmp_path)
out = StringIO()
with pytest.raises(SystemExit), redirect_stderr(out):
check_valid_checkpoint_dir(tmp_path)
out = out.getvalue().strip()
expected = f"""
checkpoint_dir '{str(tmp_path.absolute())}' is missing the files: ['lit_model.pth', 'model_config.yaml', 'tokenizer.json OR tokenizer.model', 'tokenizer_config.json'].
Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials
See all download options by running:
litgpt download
""".strip()
assert out == expected
out = StringIO()
checkpoint_dir = tmp_path / "checkpoints" / "stabilityai" / "stablelm-base-alpha-3b"
with pytest.raises(SystemExit), redirect_stderr(out):
check_valid_checkpoint_dir(checkpoint_dir)
out = out.getvalue().strip()
expected = f"""
checkpoint_dir '{str(checkpoint_dir.absolute())}' is not a checkpoint directory.
Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials
See all download options by running:
litgpt download
""".strip()
assert out == expected
out = StringIO()
checkpoint_dir.mkdir(parents=True)
foo_checkpoint_dir = tmp_path / "foo"
with pytest.raises(SystemExit), redirect_stderr(out):
check_valid_checkpoint_dir(foo_checkpoint_dir)
out = out.getvalue().strip()
expected = f"""
checkpoint_dir '{str(foo_checkpoint_dir.absolute())}' is not a checkpoint directory.
Find download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials
You have downloaded locally:
'{str(checkpoint_dir.absolute())}'
See all download options by running:
litgpt download
""".strip()
assert out == expected
def test_incremental_write(tmp_path):
sd = {str(k): torch.randn(5, 10) for k in range(3)}
sd["0"].someattr = 1
sd_expected = {k: v.clone() for k, v in sd.items()}
fn = str(tmp_path / "test.pt")
with incremental_save(fn) as f:
sd["0"] = f.store_early(sd["0"])
sd["2"] = f.store_early(sd["2"])
f.save(sd)
sd_actual = torch.load(fn)
assert sd_actual.keys() == sd_expected.keys()
assert sd_actual["0"].someattr == 1 # requires PyTorch 2.0+
for k, v_expected in sd_expected.items():
v_actual = sd_actual[k]
torch.testing.assert_close(v_expected, v_actual)
sd_actual = torch.load(fn, weights_only=True)
assert sd_actual.keys() == sd_expected.keys()
assert sd_actual["0"].someattr == 1 # requires PyTorch 2.0+
for k, v_expected in sd_expected.items():
v_actual = sd_actual[k]
torch.testing.assert_close(v_expected, v_actual)
@pytest.mark.parametrize("B", (1, 2))
@pytest.mark.parametrize("ignore_index", (None, -1, -2, -100))
def test_chunked_cross_entropy(ignore_index, B):
V = 50
T = 25
regular_logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))
if ignore_index is not None:
targets[:, [1, 4, 10, 19]] = ignore_index
baseline_loss = F.cross_entropy(
regular_logits.reshape(-1, regular_logits.size(-1)),
targets.reshape(-1),
ignore_index=(ignore_index if ignore_index is not None else -100),
)
ignore_index = ignore_index if ignore_index is not None else -100
regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0, ignore_index=ignore_index)
assert torch.equal(baseline_loss, regular_loss)
assert regular_loss.numel() == 1
chunked_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=10, ignore_index=ignore_index)
torch.testing.assert_close(chunked_loss, regular_loss)
torch.testing.assert_close(chunked_loss, baseline_loss)
logit_chunk_size = 6
assert T % logit_chunk_size != 0 # ensure leftover
chunked_logits = list(regular_logits.split(logit_chunk_size, dim=1))
chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=0, ignore_index=ignore_index)
torch.testing.assert_close(chunked_loss, regular_loss)
torch.testing.assert_close(chunked_loss, baseline_loss)
chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=10, ignore_index=ignore_index)
torch.testing.assert_close(chunked_loss, regular_loss)
torch.testing.assert_close(chunked_loss, baseline_loss)
def test_num_parameters():
model = torch.nn.Linear(2, 2)
assert num_parameters(model) == 6
assert num_parameters(model, requires_grad=True) == 6
assert num_parameters(model, requires_grad=False) == 0
model = torch.nn.Linear(2, 2)
model.bias.requires_grad = False
assert num_parameters(model) == 6
assert num_parameters(model, requires_grad=True) == 4
assert num_parameters(model, requires_grad=False) == 2
@_RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("mode", ["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"])
def test_num_parameters_bitsandbytes(mode):
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin, accelerator="cuda", devices=1)
model = torch.nn.Linear(10, 10)
model = fabric.setup(model)
assert num_parameters(model) == 110
with fabric.init_module(empty_init=True):
model = GPT.from_name("pythia-14m")
assert num_parameters(model) == 14067712
def test_cycle_iterator():
iterator = CycleIterator([])
with pytest.raises(StopIteration):
next(iterator)
iterator = CycleIterator(range(3))
assert iterator.epoch == 0
assert next(iterator) == 0
assert iterator.epoch == 0
assert next(iterator) == 1
assert iterator.epoch == 0
assert next(iterator) == 2
assert iterator.epoch == 0
assert next(iterator) == 0
assert iterator.epoch == 1
def test_parse_devices():
with pytest.raises(ValueError, match="must be 'auto' or a positive integer"):
assert parse_devices(0)
with pytest.raises(ValueError, match="must be 'auto' or a positive integer"):
assert parse_devices(-2)
with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=0):
assert parse_devices("auto") == 1 # CPU
assert parse_devices(10) == 10 # leave validation up to Fabric later on
with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=1):
assert parse_devices("auto") == 1 # CUDA
with mock.patch("litgpt.utils.torch.cuda.device_count", return_value=3):
assert parse_devices("auto") == 3
assert parse_devices(-1) == 3
assert parse_devices(5) == 5
def test_copy_config_files(fake_checkpoint_dir, tmp_path):
copy_config_files(fake_checkpoint_dir, tmp_path)
expected = {"model_config.yaml", "tokenizer_config.json", "tokenizer.json"}
contents = set(os.listdir(tmp_path))
assert expected.issubset(contents)
def test_capture_hparams():
integer = 1
string = "string"
boolean = True
none = None
path = Path("/path")
dataclass = TrainArgs()
other = torch.nn.Linear(1, 1)
hparams = capture_hparams()
assert hparams == {
"integer": integer,
"string": string,
"boolean": boolean,
"none": none,
"path": path,
"dataclass": asdict(dataclass),
"other": str(other),
}
def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
save_hyperparameters(_test_function, out_dir)
def test_save_hyperparameters(tmp_path):
with mock.patch("sys.argv", ["any.py", str(tmp_path), "--foo", "True"]):
CLI(_test_function)
with open(tmp_path / "hyperparameters.yaml", encoding="utf-8") as file:
hparams = yaml.full_load(file)
assert hparams["out_dir"] == str(tmp_path)
assert hparams["foo"] is True
assert hparams["bar"] == 1
def _test_function2(out_dir: Path, foo: bool = False, bar: int = 1):
assert False, "I only exist as a signature, but I should not run."
@pytest.mark.parametrize(
"command",
[
"any.py",
"litgpt finetune",
"litgpt finetune_full",
"litgpt finetune_lora",
"litgpt finetune_adapter",
"litgpt finetune_adapter_v2",
"litgpt pretrain",
],
)
def test_save_hyperparameters_known_commands(command, tmp_path):
with mock.patch("sys.argv", [*command.split(" "), str(tmp_path), "--foo", "True"]):
save_hyperparameters(_test_function2, tmp_path)
with open(tmp_path / "hyperparameters.yaml", encoding="utf-8") as file:
hparams = yaml.full_load(file)
assert hparams["out_dir"] == str(tmp_path)
assert hparams["foo"] is True
assert hparams["bar"] == 1
def test_choose_logger(tmp_path):
assert isinstance(choose_logger("csv", out_dir=tmp_path, name="csv"), CSVLogger)
if _TENSORBOARD_AVAILABLE:
assert isinstance(choose_logger("tensorboard", out_dir=tmp_path, name="tb"), TensorBoardLogger)
if _WANDB_AVAILABLE:
assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)
if _MLFLOW_AVAILABLE or _MLFLOW_SKINNY_AVAILABLE:
assert isinstance(choose_logger("mlflow", out_dir=tmp_path, name="wandb"), MLFlowLogger)
if _LITLOGGER_AVAILABLE:
assert isinstance(choose_logger("litlogger", out_dir=tmp_path, name="litlogger"), LitLogger)
with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
choose_logger("foo", out_dir=tmp_path, name="foo")
@pytest.mark.parametrize(
"path_type, input_path, expected",
[
("relative", "some/relative/path", "some/relative/path"),
("absolute", "/usr/absolute/path", "/usr/absolute/path"),
("env_relative", "some/relative/path", "prefix/some/relative/path"),
("env_absolute", "/usr/absolute/path", "/usr/absolute/path"),
],
)
def test_init_out_dir(path_type, input_path, expected):
if path_type.startswith("env_"):
with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
result = init_out_dir(input_path)
assert result == Path(expected), f"Failed for {path_type} with input {input_path} (result {result})"
else:
result = init_out_dir(input_path)
if "LIGHTNING_ARTIFACTS_DIR" not in os.environ:
assert result == Path(expected), f"Failed for {path_type} with input {input_path} (result {result})"
else:
assert result == Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / expected, (
f"Failed for {path_type} with input {input_path} (result {result})"
)
def test_find_resume_path(tmp_path):
assert find_resume_path(resume=None, out_dir=Path("does/not/exist")) is None
assert find_resume_path(resume=Path("does/not/exist"), out_dir=Path("does/not/matter")) == Path("does/not/exist")
assert find_resume_path(resume=(tmp_path / "checkpoint.pt"), out_dir=Path("does/not/matter")) == (
tmp_path / "checkpoint.pt"
)
# `resume='auto'` does not enforce the checkpoint to exist
assert find_resume_path(resume="auto", out_dir=Path("does/not/exist")) is None
# `resume=True` requires a checkpoint to exist
with pytest.raises(FileNotFoundError, match="You passed `--resume=True`, but no checkpoint file was found"):
find_resume_path(resume=True, out_dir=Path("does/not/exist"))
with pytest.raises(FileNotFoundError, match="You passed `--resume=True`, but no checkpoint file was found"):
find_resume_path(resume=True, out_dir=tmp_path)
(tmp_path / "step-001").mkdir()
(tmp_path / "step-001" / "lit_model.pth").touch()
(tmp_path / "step-002").mkdir()
(tmp_path / "step-002" / "lit_model.pth").touch()
(tmp_path / "step-003").mkdir()
(tmp_path / "step-003" / "lit_model.pth").touch()
assert find_resume_path(resume=True, out_dir=tmp_path) == (tmp_path / "step-003" / "lit_model.pth")
assert find_resume_path(resume="auto", out_dir=tmp_path) == (tmp_path / "step-003" / "lit_model.pth")
@pytest.fixture
def model_parameters():
return [torch.nn.Parameter(torch.randn(2, 2))]
def test_instantiate_bnb_optimizer_with_str(model_parameters):
import bitsandbytes as bnb
with mock.patch("litgpt.utils.get_argument_names", return_value={"lr", "eps", "weight_decay"}):
optimizer = instantiate_bnb_optimizer("AdamW", model_parameters)
assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)
def test_instantiate_bnb_optimizer_with_dict(model_parameters):
import bitsandbytes as bnb
optimizer_dict = {"class_path": "AdamW", "init_args": {"lr": 0.01}}
with mock.patch("litgpt.utils.get_argument_names", return_value={"lr", "eps", "weight_decay"}):
optimizer = instantiate_bnb_optimizer(optimizer_dict, model_parameters)
assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)
assert optimizer.param_groups[0]["lr"] == 0.01
def test_instantiate_bnb_optimizer_with_invalid_str(model_parameters):
with pytest.raises(ValueError, match="only supports the AdamW"):
instantiate_bnb_optimizer("SGD", model_parameters)
def test_instantiate_torch_optimizer_with_str(model_parameters):
optimizer = instantiate_torch_optimizer("Adam", model_parameters, lr=0.01)
assert isinstance(optimizer, torch.optim.Adam)
assert optimizer.param_groups[0]["lr"] == 0.01
def test_instantiate_torch_optimizer_with_class(model_parameters):
optimizer = instantiate_torch_optimizer(
{"class_path": "torch.optim.Adam", "init_args": {"lr": 123}}, model_parameters, lr=0.02
)
assert isinstance(optimizer, torch.optim.Adam)
# init args gets overridden
assert optimizer.param_groups[0]["lr"] == 0.02
@pytest.mark.parametrize(
"input_path, expected",
[
(Path("checkpoints/my_model"), Path("checkpoints/my_model")),
(Path("checkpoints/my_model"), Path("./checkpoints/my_model")),
],
)
def test_extend_checkpoint_dir_is_prefixed(input_path, expected):
original_dir = Path.cwd() # Save the current directory
with TemporaryDirectory() as tmp_dir:
os.chdir(tmp_dir)
try:
if not input_path.is_absolute():
input_path = Path(tmp_dir) / input_path
if not expected.is_absolute():
expected = Path(tmp_dir) / expected
input_path.parent.mkdir(parents=True, exist_ok=True)
input_path.touch(exist_ok=True)
assert extend_checkpoint_dir(input_path) == expected
finally:
os.chdir(original_dir) # Reset the current directory
@pytest.mark.parametrize(
"input_path, expected",
[
(Path("my_model"), Path("checkpoints/my_model")),
(Path("my_model"), Path("./checkpoints/my_model")),
],
)
def test_extend_checkpoint_dir(input_path, expected):
original_dir = Path.cwd() # Save the current directory
with TemporaryDirectory() as tmp_dir:
os.chdir(tmp_dir)
try:
if not input_path.is_absolute():
input_path = Path(tmp_dir) / "checkpoints" / input_path
if not expected.is_absolute():
expected = Path(tmp_dir) / expected
input_path.parent.mkdir(parents=True, exist_ok=True)
input_path.touch(exist_ok=True)
assert extend_checkpoint_dir(input_path) == expected
finally:
os.chdir(original_dir) # Reset the current directory
@pytest.mark.parametrize(
"input_path, expected",
[
(Path("my_model"), Path("my_model")),
(Path("/my_model"), Path("/my_model")),
],
)
def test_extend_checkpoint_dir_dont_exist(input_path, expected):
assert extend_checkpoint_dir(input_path) == expected
def test_file_size_below_limit_on_cpu():
# Test file size below limit on CPU
with NamedTemporaryFile() as temp_file:
with mock.patch("os.path.getsize", return_value=4_000_000_000):
size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")
assert size == 4_000_000_000
def test_file_size_above_limit_on_cpu():
# Test file size above limit on CPU
with NamedTemporaryFile() as temp_file:
with mock.patch("os.path.getsize", return_value=4_600_000_000):
with pytest.warns(UserWarning) as record:
size = check_file_size_on_cpu_and_warn(temp_file.name, "cpu")
assert size == 4_600_000_000
assert "over 4.2 GB" in str(record[0].message)
def test_file_size_above_limit_on_gpu():
# Test file size above limit on GPU should not warn
with NamedTemporaryFile() as temp_file:
with mock.patch("os.path.getsize", return_value=4_600_000_000):
size = check_file_size_on_cpu_and_warn(temp_file.name, "gpu")
assert size == 4_600_000_000
@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
"""Fixture to mock torch.cuda.is_available() to return True."""
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for NVIDIA GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA RTX A6000"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
@pytest.fixture
def mock_amd_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for AMD GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "AMD Instinct MI250X"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
@pytest.fixture
def all_nvlink_connected_output():
return mock.MagicMock(
stdout=""" GPU0 GPU1 GPU2 GPU3
GPU0 X NV12 NV12 NV12
GPU1 NV12 X NV12 NV12
GPU2 NV12 NV12 X NV12
GPU3 NV12 NV12 NV12 X""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_all_nvlink_connected(
mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
@pytest.fixture
def nvlink_partially_connected_output():
return mock.MagicMock(
stdout=""" GPU0 GPU1 GPU2 GPU3 CPU Affinity
GPU0 X NV1 SYS SYS 0-7
GPU1 NV1 X SYS SYS 0-7
GPU2 SYS SYS X NV1 8-15
GPU3 SYS SYS NV1 X 8-15
Legend:
X = Self
NV1 = Connected via NVLink with 1 hop
SYS = Connected via the PCIe or CPU subsystem""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(
mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_partially_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
@pytest.fixture
def nvlink_not_connected_output():
return mock.MagicMock(
stdout=""" GPU0 GPU1 GPU2 GPU3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X PHB PHB PHB 0-47 0 N/A
GPU1 PHB X PHB PHB 0-47 0 N/A
GPU2 PHB PHB X PHB 0-47 0 N/A
GPU3 PHB PHB PHB X 0-47 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(
mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_not_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
@pytest.fixture
def nvlink_all_gpu_connected_but_other_connected_output():
return mock.MagicMock(
stdout=""" GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A
GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A
GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A
GPU3 NV12 NV12 NV12 X NV12 NV12 NV12 NV12 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS 0-63,128-191 0 N/A
GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12 SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB 64-127,192-254 1 N/A
GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB 64-127,192-254 1 N/A
GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A
GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A
NIC0 SYS SYS PXB PXB SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS SYS SYS SYS
NIC1 SYS SYS PXB PXB SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS SYS SYS SYS
NIC2 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS X PXB SYS SYS SYS SYS SYS SYS
NIC3 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS PXB X SYS SYS SYS SYS SYS SYS
NIC4 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS X PXB SYS SYS SYS SYS
NIC5 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS PXB X SYS SYS SYS SYS
NIC6 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PIX SYS SYS
NIC7 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PIX X SYS SYS
NIC8 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PXB
NIC9 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PXB X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
NIC8: mlx5_8
NIC9: mlx5_9
""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run,
nvlink_all_gpu_connected_but_other_connected_output,
mock_cuda_is_available_true,
mock_nvidia_device_properties,
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
return mock.MagicMock(
stdout="""
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV1 0-15 0 N/A
GPU1 NV1 X 0-15 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
"""
rocm-smi --showtopotype on ROCm 6.0.3+
"""
return mock.MagicMock(
stdout="""
=============================== ROCm System Management Interface ============================
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
================================== End of ROCm SMI Log ===================================
""",
returncode=0,
)
@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via XGMI.")
@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("No GPUs available")
@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(
mock_run, monkeypatch, mock_cuda_is_available_true
):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("Unrecognized GPU vendor: GARAGE DIY HYPERSCALER GPU")
def test_fix_and_load_json():
# Test 1: Invalid JSON string with a trailing comma
invalid_json_trailing_comma = """
{
"_from_model_config": true,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0",
"do_sample": true,
"temperature": 0.6,
"top_p": 0.9,
}
"""
expected_output_trailing_comma = {
"_from_model_config": True,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0",
"do_sample": True,
"temperature": 0.6,
"top_p": 0.9,
}
result_trailing_comma = fix_and_load_json(invalid_json_trailing_comma)
assert result_trailing_comma == expected_output_trailing_comma
# Test 2: Invalid JSON string with missing commas between properties
invalid_json_missing_commas = """
{
"_from_model_config": true,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0"
"do_sample": true,
"temperature": 0.6,
"top_p": 0.9,
}
"""
expected_output_missing_commas = {
"_from_model_config": True,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0",
"do_sample": True,
"temperature": 0.6,
"top_p": 0.9,
}
result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
assert result_missing_commas == expected_output_missing_commas
def test_select_sft_generate_example():
eval_mock = mock.MagicMock()
data_mock = mock.MagicMock()
test_dataset = {"data": [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]}
train_dataset = {"data": [{"instruction": "Train instruction 1"}, {"instruction": "Train instruction 2"}]}
data_mock.test_dataset.data = test_dataset["data"]
data_mock.train_dataset.data = train_dataset["data"]
# Test "first" instruction from test dataset
eval_mock.evaluate_example = "first"
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Test instruction 1"
# Test "first" instruction from train dataset when test dataset is empty
data_mock.test_dataset.data = []
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Train instruction 1"
# Test random selection from test dataset
eval_mock.evaluate_example = "random"
data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
with mock.patch("random.randint", return_value=1):
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Test instruction 2"
# Test random selection from train dataset when test dataset is empty
data_mock.test_dataset.data = []
with mock.patch("random.randint", return_value=1):
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Train instruction 2"
# Test specific index from test dataset
eval_mock.evaluate_example = 1
data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Test instruction 2"
# Test specific index from train dataset when test dataset has fewer elements
data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
instruction = select_sft_generate_example(eval_mock, data_mock)
assert instruction == "Train instruction 2"
# Test out-of-range index
eval_mock.evaluate_example = 2
data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
data_mock.train_dataset.data = [{"instruction": "Train instruction 1"}]
with pytest.raises(IndexError):
select_sft_generate_example(eval_mock, data_mock)
# Test unknown evaluation type
eval_mock.evaluate_example = "unknown"
with pytest.raises(ValueError):
select_sft_generate_example(eval_mock, data_mock)
================================================
FILE: tests/test_yarn.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM
from litgpt import Config
from litgpt.model import Block
@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device):
"""Test DeepSeek V3 block (attention + MLP + norms) with YaRN RoPE scaling - litgpt vs hf"""
# Use layer_idx=0 to test dense MLP instead of MoE
layer_idx = 0
# YaRN configuration
yarn_config = dict(
factor=8.0,
beta_fast=32.0,
beta_slow=1.0,
original_max_seq_len=4096,
mscale=1.0,
mscale_all_dim=0.8,
)
config_litgpt = Config(
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
norm_eps=1e-6,
norm_class_name="RMSNorm",
bias=False,
parallel_residual=False,
mlp_class_name="LLaMAMoE",
intermediate_size=128,
rope_interleave=True,
rope_adjustments=yarn_config, # YaRN config
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
first_k_dense_replace=3, # Use dense MLP for first 3 layers
)
# HF config with YaRN
rope_parameters = {
"type": "yarn",
"rope_theta": 10000.0,
"factor": yarn_config["factor"],
"beta_fast": yarn_config["beta_fast"],
"beta_slow": yarn_config["beta_slow"],
"original_max_position_embeddings": yarn_config["original_max_seq_len"],
"mscale": yarn_config["mscale"],
"mscale_all_dim": yarn_config["mscale_all_dim"],
}
config_hf = DeepseekV3Config(
padded_vocab_size=10000,
num_hidden_layers=1,
vocab_size=10000,
hidden_size=64,
intermediate_size=128,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=16,
qk_rope_head_dim=8,
qk_nope_head_dim=8,
v_head_dim=16,
rope_interleave=True,
first_k_dense_replace=3,
rms_norm_eps=1e-6,
rope_scaling=rope_parameters, # YaRN config
)
# Debug: Check if HF config has rope_parameters
print("\n=== HF Config Debug ===")
print(f"config_hf.rope_parameters: {config_hf.rope_scaling}")
block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device)
model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
block_hf = model_hf.model.layers[layer_idx]
block_litgpt.eval()
block_hf.eval()
sync_block_weights(block_litgpt, block_hf)
hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)
# Prepare RoPE sin/cos tables using YaRN computation
from litgpt.model import build_rope_cache
rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"]
# Build YaRN RoPE cache for LitGPT
cos_litgpt, sin_litgpt = build_rope_cache(
seq_len=seq_len,
n_elem=rope_head_dim,
device=device,
base=config_litgpt.rope_base,
extra_config={
"factor": yarn_config["factor"],
"beta_fast": yarn_config["beta_fast"],
"beta_slow": yarn_config["beta_slow"],
"original_max_seq_len": yarn_config["original_max_seq_len"],
"mscale": yarn_config["mscale"],
"mscale_all_dim": yarn_config["mscale_all_dim"],
},
)
# Get YaRN RoPE embeddings from HF (rotary_emb is on model level, not layer level)
rotary_emb = model_hf.model.rotary_emb
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
cos_hf, sin_hf = rotary_emb(hidden_states, position_ids)
# Expand dimensions for batch and broadcast
cos_litgpt = cos_litgpt.unsqueeze(0).expand(batch_size, -1, -1)
sin_litgpt = sin_litgpt.unsqueeze(0).expand(batch_size, -1, -1)
# Compare RoPE embeddings first
print("\n=== RoPE Embeddings Comparison ===")
print(f"LitGPT cos/sin shape: {cos_litgpt.shape}, {sin_litgpt.shape}")
print(f"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}")
print(f"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}")
print(f"Sin max diff: {(sin_litgpt - sin_hf).abs().max()}")
print(f"\nLitGPT cos sample [0,0,:4]: {cos_litgpt[0, 0, :4]}")
print(f"HF cos sample [0,0,:4]: {cos_hf[0, 0, :4]}")
print(f"LitGPT cos min/max: {cos_litgpt.min():.4f} / {cos_litgpt.max():.4f}")
print(f"HF cos min/max: {cos_hf.min():.4f} / {cos_hf.max():.4f}")
# Check inv_freq from both
print("\n=== Checking inv_freq ===")
print(f"HF rotary_emb.inv_freq shape: {rotary_emb.inv_freq.shape}")
print(f"HF inv_freq: {rotary_emb.inv_freq}")
print(f"HF attention_scaling: {rotary_emb.attention_scaling}")
# Use the same embeddings for both (LitGPT's)
cos = cos_litgpt
sin = sin_litgpt
causal_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
# Run forward passes
output_litgpt = block_litgpt(hidden_states, cos, sin)
output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)
if isinstance(output_hf, tuple):
output_hf = output_hf[0]
max_diff = (output_litgpt - output_hf).abs().max()
print("\n=== DEBUG INFO ===")
print(f"Max diff: {max_diff}")
print(f"Output litgpt mean: {output_litgpt.mean()}, std: {output_litgpt.std()}")
print(f"Output hf mean: {output_hf.mean()}, std: {output_hf.std()}")
print(f"Cos/sin shape: {cos.shape}, {sin.shape}")
print(f"Hidden states shape: {hidden_states.shape}")
# Check if the issue is in attention or MLP
if hasattr(output_litgpt, "shape") and hasattr(output_hf, "shape"):
if output_litgpt.shape != output_hf.shape:
print(f"Shape mismatch! litgpt: {output_litgpt.shape}, hf: {output_hf.shape}")
assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), f"FAILED: Max diff: {max_diff}"
def sync_weights(litgpt_model, hf_model):
"""Copies weights from lit-gpt model to HF model."""
print("Synchronizing weights...")
with torch.no_grad():
hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)
hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)
hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)
hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)
hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)
hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)
hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)
print("Synchronization complete.")
def sync_block_weights(block_litgpt, block_hf):
"""Synchronize all weights from LitGPT block to HF block."""
print("Synchronizing block weights...")
with torch.no_grad():
# Sync attention weights
sync_weights(block_litgpt.attn, block_hf.self_attn)
# Sync MLP weights (assumes dense MLP, not MoE)
block_hf.mlp.gate_proj.weight.copy_(block_litgpt.mlp.fc_1.weight)
block_hf.mlp.up_proj.weight.copy_(block_litgpt.mlp.fc_2.weight)
block_hf.mlp.down_proj.weight.copy_(block_litgpt.mlp.proj.weight)
# Sync normalization layers
block_hf.input_layernorm.weight.copy_(block_litgpt.norm_1.weight)
block_hf.post_attention_layernorm.weight.copy_(block_litgpt.norm_2.weight)
print("Block synchronization complete.")
================================================
FILE: tutorials/0_to_litgpt.md
================================================
# Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs
This tutorial walks you through the main features and usage patterns for ⚡️LitGPT, a library for pretraining, finetuning, and using LLMs that focuses on an efficient user experience while being developer-friendly.
The topics, following the installation of LitGPT, are in chronological order, reflecting the steps in an LLM lifecycle: Pretraining → Finetuning → Inference.
However, it is also possible, and even common, to use and deploy models with LitGPT without pretraining and finetuning. So, if you are not interested in pretraining and finetuning, please feel free to skip these sections.
## Install LitGPT
LitGPT is available as a Python library from the PyPI package repository, and we recommend installing it using Python's `pip` installer module, including all required package dependencies:
```bash
pip install 'litgpt[all]'
```
Alternatively, if you are a researcher or developer planning to make changes to LitGPT, you can clone the GitHub repository and install it from a local folder as follows:
```
git clone https://github.com/Lightning-AI/litgpt.git
cd litgpt
pip install -e '.[all]'
```
## Pretrain LLMs
Pretraining LLMs requires substantial compute resources and time commitment. For that reason, most researchers and practitioners prefer to skip this step and continue with the *Download pretrained model weights* section instead.
However, if you feel adventurous and want to pretrain your own LLM, here's how.
First, we have to decide which type of model architecture we want to use. We list the available architectures by using the `pretrain` command without any additional arguments:
```bash
litgpt pretrain list
```
This prints a list of all available model architectures in alphabetical order:
```
Camel-Platypus2-13B
Camel-Platypus2-70B
CodeLlama-13b-Python-hf
...
EleutherAI/pythia-410m
...
vicuna-13b-v1.3
vicuna-13b-v1.5
vicuna-13b-v1.5-16k
vicuna-33b-v1.3
vicuna-7b-v1.3
vicuna-7b-v1.5
vicuna-7b-v1.5-16k
```
Suppose we want to pretraining the 1.1B parameter small `tiny-llama-1.1b` model. Before starting finetuning, we must also choose and download a tokenizer.
We can download a tokenizer via the `download` command. Note that running `litgpt download list` will also print a list of all available models and tokenizers to download.
To filter for specific models, e.g., TinyLlama, we can use the `grep` command in our terminal:
```bash
litgpt download list | grep TinyLlama
```
This prints
```
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
TinyLlama/TinyLlama-1.1B-Chat-v1.0
```
Let's now download the tokenizer corresponding to `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T` that we can then use to pretrain the TinyLlama model:
```
litgpt download \
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
--tokenizer_only true
```
(when specified)
Next, we can pretrain the model on the OpenWebText dataset with the default setting as follows:
```bash
litgpt pretrain tiny-llama-1.1b \
--data OpenWebText \
--tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
```
If you are interested in additional settings, you can use the help command as follows:
```
litgpt pretrain --help
```
> [!TIP]
> Above, we only covered the most basic commands for pretraining a model using LitGPT. We highly recommend checking the resources below if you are interested in pretraining a model.
**More information and additional resources**
- [tutorials/pretrain](./pretrain.md): General information about pretraining in LitGPT
- [tutorials/pretrain_tinyllama](./pretrain_tinyllama.md): A tutorial for finetuning a 1.1B TinyLlama model on 3 trillion tokens
- [config_hub/pretrain](../config_hub/pretrain): Pre-made config files for pretraining that work well out of the box
- Project templates in reproducible environments with multi-GPU and multi-node support:
- [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
- [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)
- [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)
## Download pretrained model weights
Most practical use cases, like LLM inference (/chat) or finetuning, involve using pretrained model weights. LitGPT supports a large number of model weights, which can be listed by executing the `download` with `list` as an argument:
```bash
litgpt download list
```
This will print a (long) list of all supported pretrained models (abbreviated for readability below):
```
..
google/gemma-2b
...
meta-llama/Llama-2-7b-hf
...
microsoft/phi-2
...
mistralai/Mixtral-8x7B-Instruct-v0.1
...
```
To download the model weights, provide one of the model strings above as input argument:
```bash
litgpt download microsoft/phi-2
```
```
model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s]
model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s]
...
Converting checkpoint files to LitGPT format.
Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin
...
Saving converted checkpoint to checkpoints/microsoft/phi-2
```
> [!TIP]
> Note that some models, such as Llama 2, require that you accept Meta AI's terms of service for this model, and you need to use a special access token via the `litgpt download ... --access_token ...` option. For more information, visit the respective Model Hub website, e.g., [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf). The access token can be created under your Model Hub in the `Profile > Access Tokens` menu.
By default, the weights are going to be stored in a `./checkpoints` subdirectory:
```bash
ls -lh checkpoints/microsoft/phi-2/
```
```
total 11G
-rw-r--r-- 1 sebastian sebastian 863 Mar 19 21:14 config.json
-rw-r--r-- 1 sebastian sebastian 124 Mar 19 21:14 generation_config.json
-rw-r--r-- 1 sebastian sebastian 5.2G Mar 19 21:15 lit_model.pth
-rw-r--r-- 1 sebastian sebastian 4.7G Mar 19 21:15 model-00001-of-00002.bin
-rw-r--r-- 1 sebastian sebastian 538M Mar 19 21:15 model-00002-of-00002.bin
-rw-r--r-- 1 sebastian sebastian 528 Mar 19 21:15 model_config.yaml
-rw-r--r-- 1 sebastian sebastian 2.1M Mar 19 21:14 tokenizer.json
-rw-r--r-- 1 sebastian sebastian 7.2K Mar 19 21:14 tokenizer_config.json
```
The model is now ready for inference and chat, for example, using the `chat` command on the checkpoint directory:
```bash
litgpt chat microsoft/phi-2
```
```
Now chatting with phi-2.
To exit, press 'Enter' on an empty prompt.
Seed set to 1234
>> Prompt: Why are LLMs so useful?
>> Reply: When building applications or operating systems, you can use LLMs to know how a computer should respond to your commands. This can make your programs run faster and more efficiently.
Time for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens
>> Prompt:
```
> [!TIP]
> Use `--multiline true` to support prompts that require multiple input lines.
**More information and additional resources**
- [tutorials/download_model_weights](download_model_weights.md): A more comprehensive download tutorial, tips for GPU memory limitations, and more
## Finetune LLMs
LitGPT supports several methods of supervised instruction finetuning, which allows you to finetune models to follow instructions.
Datasets for Instruction-finetuning are usually formatted in the following way:
Alternatively, datasets for instruction finetuning can also contain an `'input'` field:
In an instruction-finetuning context, "full" finetuning means updating all model parameters as opposed to only a subset. Adapter and LoRA (short for low-rank adaptation) are methods for parameter-efficient finetuning that only require updating a small fraction of the model weights.
Parameter-efficient finetuning is much more resource-efficient and cheaper than full finetuning, and it often results in the same good performance on downstream tasks.
In the following example, we will use LoRA for finetuning, which is one of the most popular LLM finetuning methods. (For more information on how LoRA works, please see [Code LoRA from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).)
Before we start, we have to download a model as explained in the previous "Download pretrained model" section above:
```bash
litgpt download microsoft/phi-2
```
The LitGPT interface can be used via command line arguments and configuration files. We recommend starting with the configuration files from the [config_hub](../config_hub) and either modifying them directly or overriding specific settings via the command line. For example, we can use the following setting to train the downloaded 2.7B parameter `microsoft/phi-2` model, where we set `--max_steps 5` for a quick test run.
If you have downloaded or cloned the LitGPT repository, you can provide the `config` file via a relative path:
```bash
litgpt finetune_lora microsoft/phi-2\
--config config_hub/finetune/phi-2/lora.yaml \
--train.max_steps 5
```
Alternatively, you can provide a URL:
```bash
litgpt finetune_lora microsoft/phi-2\
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/phi-2/lora.yaml \
--train.max_steps 5
```
> [!TIP]
> Note that the config file above will finetune the model on the `Alpaca2k` dataset on 1 GPU and save the resulting files in an `out/finetune/lora-phi-2` directory. All of these settings can be changed via a respective command line argument or by changing the config file.
> To see more options, execute `litgpt finetune_lora --help`.
Running the previous finetuning command will initiate the finetuning process, which should only take about a minute on a GPU due to the `--train.max_steps 5` setting.
```
{'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'), # TODO
'data': Alpaca2k(mask_prompt=False,
val_split_fraction=0.03847,
prompt_style=,
ignore_index=-100,
seed=42,
num_workers=4,
download_dir=PosixPath('data/alpaca2k')),
'devices': 1,
'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
'logger_name': 'csv',
'lora_alpha': 16,
'lora_dropout': 0.05,
'lora_head': True,
'lora_key': True,
'lora_mlp': True,
'lora_projection': True,
'lora_query': True,
'lora_r': 8,
'lora_value': True,
'num_nodes': 1,
'out_dir': PosixPath('out/finetune/lora-phi-2'),
'precision': 'bf16-true',
'quantize': None,
'seed': 1337,
'train': TrainArgs(save_interval=800,
log_interval=1,
global_batch_size=8,
micro_batch_size=4,
lr_warmup_steps=10,
epochs=1,
max_tokens=None,
max_steps=5,
max_seq_length=512,
tie_embeddings=None,
learning_rate=0.0002,
weight_decay=0.0,
beta1=0.9,
beta2=0.95,
max_norm=None,
min_lr=6e-05)}
Seed set to 1337
Number of trainable parameters: 12,226,560
Number of non-trainable parameters: 2,779,683,840
The longest sequence length in the train data is 512, the model's maximum sequence length is 512 and context length is 2048
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.
### Response:
I recommend you watch "Parasite" because it's a critically acclaimed movie that won multiple awards, including the Academy Award for Best Picture. It's a thought-provoking and suspenseful film that will keep you on the edge of your seat. The movie also tackles social and economic inequalities, making it a must-watch for anyone interested in meaningful storytelling.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MeanMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs) # noqa: B028
Missing logger folder: out/finetune/lora-phi-2/logs/csv
Epoch 1 | iter 1 step 0 | loss train: 1.646, val: n/a | iter time: 820.31 ms
Epoch 1 | iter 2 step 1 | loss train: 1.660, val: n/a | iter time: 548.72 ms (step)
Epoch 1 | iter 3 step 1 | loss train: 1.687, val: n/a | iter time: 300.07 ms
Epoch 1 | iter 4 step 2 | loss train: 1.597, val: n/a | iter time: 595.27 ms (step)
Epoch 1 | iter 5 step 2 | loss train: 1.640, val: n/a | iter time: 260.75 ms
Epoch 1 | iter 6 step 3 | loss train: 1.703, val: n/a | iter time: 568.22 ms (step)
Epoch 1 | iter 7 step 3 | loss train: 1.678, val: n/a | iter time: 511.70 ms
Epoch 1 | iter 8 step 4 | loss train: 1.741, val: n/a | iter time: 514.14 ms (step)
Epoch 1 | iter 9 step 4 | loss train: 1.689, val: n/a | iter time: 423.59 ms
Epoch 1 | iter 10 step 5 | loss train: 1.524, val: n/a | iter time: 603.03 ms (step)
Training time: 11.20s
Memory used: 13.90 GB
Saving LoRA weights to 'out/finetune/lora-phi-2/final/lit_model.pth.lora'
Saved merged weights to 'out/finetune/lora-phi-2/final/lit_model.pth'
```
Notice that the LoRA script saves both the LoRA weights (`'out/finetune/lora-phi-2/final/lit_model.pth.lora'`) and the LoRA weight merged back into the original model (`'out/finetune/lora-phi-2/final/lit_model.pth'`) for convenience. This allows us to use the finetuned model via the `chat` function directly:
```bash
litgpt chat out/finetune/lora-phi-2/final/
```
```
Now chatting with phi-2.
To exit, press 'Enter' on an empty prompt.
Seed set to 1234
>> Prompt: Why are LLMs so useful?
>> Reply: LLMs are useful because they can be trained to perform various natural language tasks, such as language translation, text generation, and question-answering. They are also able to understand the context of the input data, which makes them particularly useful for tasks such as sentiment analysis and text summarization. Additionally, because LLMs can learn from large amounts of data, they are able to generalize well and perform well on new data.
Time for inference: 2.15 sec total, 39.57 tokens/sec, 85 tokens
>> Prompt:
```
**More information and additional resources**
- [tutorials/prepare_dataset](prepare_dataset.md): A summary of all out-of-the-box supported datasets in LitGPT and utilities for preparing custom datasets
- [tutorials/finetune](finetune.md): An overview of the different finetuning methods supported in LitGPT
- [tutorials/finetune_full](finetune_full.md): A tutorial on full-parameter finetuning
- [tutorials/finetune_lora](finetune_lora.md): Options for parameter-efficient finetuning with LoRA and QLoRA
- [tutorials/finetune_adapter](finetune_adapter.md): A description of the parameter-efficient Llama-Adapter methods supported in LitGPT
- [tutorials/oom](oom.md): Tips for dealing with out-of-memory (OOM) errors
- [config_hub/finetune](../config_hub/finetune): Pre-made config files for finetuning that work well out of the box
## LLM inference
To use a downloaded or finetuned model for chat, you only need to provide the corresponding checkpoint directory containing the model and tokenizer files. For example, to chat with the phi-2 model from Microsoft, download it as follows, as described in the "Download pretrained model" section:
```bash
litgpt download microsoft/phi-2
```
```
model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s]
model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s]
...
Converting checkpoint files to LitGPT format.
Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin
...
Saving converted checkpoint to checkpoints/microsoft/phi-2
```
Then, chat with the model using the following command:
```bash
litgpt chat microsoft/phi-2
```
```
Now chatting with phi-2.
To exit, press 'Enter' on an empty prompt.
Seed set to 1234
>> Prompt: What is the main difference between a large language model and a traditional search engine?
>> Reply: A large language model uses deep learning algorithms to analyze and generate natural language, while a traditional search engine uses algorithms to retrieve information from web pages.
Time for inference: 1.14 sec total, 26.26 tokens/sec, 30 tokens
```
> [!TIP]
> Most model weights are already represented in an efficient bfloat16 format. However, if the model currently exceeds your GPU memory, you can try to pass the `--precision bf16-true` option. In addition, you can check the quantization documentation for further optimization, which is linked below.
**More information and additional resources**
- [tutorials/inference](inference.md): Chat and inference tutorial
- [tutorials/quantize](quantize.md): Quantizing models to reduce GPU memory requirements
## Using the LitGPT Python API for Inference
The previous section explained how to use the `litgpt chat` command line interface for inference. Alternatively, LitGPT also offers a Python API approach to generate text using an LLM:
```python
from litgpt import LLM
llm = LLM.load("microsoft/phi-2")
text = llm.generate("What do Llamas eat?", top_k=1, max_new_tokens=30)
print(text)
```
Note that the if you pass a supported model name to `LLM.load()`, as shown above, it will download the model from the HF hub if it doesn't exist locally, yet (use `litgpt download list` on the command line to get a list of all currently supported models.)
Alternatively, to load model from a local path, just provide the corresponding path as input to the `load` method:
```python
llm = LLM.load("path/to/my/local/checkpoint")
```
**More information and additional resources**
- [tutorials/python-api](python-api.md): The LitGPT Python API documentation
## Evaluating models
LitGPT comes with a handy `litgpt evaluate` command to evaluate models with [Eleuther AI's Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). For example, to evaluate the previously downloaded `microsoft/phi-2` model on several tasks available from the Evaluation Harness, you can use the following command:
```bash
litgpt evaluate microsoft/phi-2
--batch_size 16 \
--tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge"
```
(A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).)
## Deploy LLMs
You can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT built-in serving capabilities:
```bash
# 1) Download a pretrained model (alternatively, use your own finetuned model)
litgpt download microsoft/phi-2
# 2) Start the server
litgpt serve microsoft/phi-2
```
```python
# 3) Use the server (in a separate session)
import requests, json
response = requests.post(
"http://127.0.0.1:8000/predict",
json={"prompt": "Fix typos in the following sentence: Example input"}
)
print(response.json()["output"])
```
This prints:
```
Instruct: Fix typos in the following sentence: Example input
Output: Example input.
```
**More information and additional resources**
- [tutorials/deploy](deploy.md): A full deployment tutorial and example
## Converting LitGPT model weights to `safetensors` format
Sometimes, it can be useful to convert LitGPT model weights for third-party and external tools. For example, we can convert a LitGPT model to the Hugging Face format and save it via `.safetensors` files, which we can do as follows:
```bash
litgpt convert_from_litgpt microsoft/phi-2 out/converted_model/
```
Certain tools like the `.from_pretrained` method in Hugging Face `transformers` also require the original `config.json` file that originally came with the downloaded model:
```bash
cp checkpoints/microsoft/phi-2/config.json out/converted_model/config.json
```
You can now load the model into a Hugging Face transformers model and safe it in a `.safetensors` format as follows:
```bash
import torch
from transformers import AutoModel
# Load model
state_dict = torch.load('out/converted_model/model.pth')
model = AutoModel.from_pretrained(
"microsoft/phi-2", state_dict=state_dict
)
# Save .safetensors files
model.save_pretrained("out/converted_model/")
```
```
⚡ ~/litgpt ls -lh out/converted_model
total 16G
-rwxr--r-- 1 sebastian sebastian 891 Mar 20 17:08 config.json
-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:08 model-00001-of-00003.safetensors
-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:09 model-00002-of-00003.safetensors
-rw-r--r-- 1 sebastian sebastian 601M Mar 20 17:09 model-00003-of-00003.safetensors
-rw-r--r-- 1 sebastian sebastian 5.2G Mar 20 16:30 model.pth
-rw-r--r-- 1 sebastian sebastian 33K Mar 20 17:09 model.safetensors.index.json
```
You can then use the model with external tools, for example, Eleuther AI's [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) (see the `lm_eval` installation instructions [here](https://github.com/EleutherAI/lm-evaluation-harness?tab=readme-ov-file#install)).
The LM Evaluation Harness requires a tokenizer to be present in the model checkpoint folder, which we can copy from the original download checkpoint:
```bash
# Copy the tokenizer needed by the Eval Harness
cp checkpoints/microsoft/phi-2/tokenizer*
out/converted_model
```
Then, we can run the Evaluation Harness as follows:
```bash
lm_eval --model hf \
--model_args pretrained="out/converted_model" \
--tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \
--device "cuda:0" \
--batch_size 4
```
> [!TIP]
> The Evaluation Harness tasks above are those used in Open LLM Leaderboard. You can find a list all supported tasks [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).
**More information and additional resources**
- [tutorials/convert_lit_models](./convert_lit_models.md): Tutorial on converting LitGPT weights
## Get involved!
We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the [GitHub Issue](https://github.com/Lightning-AI/litgpt/issues) tracker.
We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.
> [!TIP]
> Unsure about contributing? Check out our [How to Contribute to LitGPT](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) guide.
If you have general questions about building with LitGPT, please [join our Discord](https://discord.gg/VptPCZkGNa).
================================================
FILE: tutorials/convert_hf_checkpoint.md
================================================
# Converting Hugging Face Transformers to LitGPT weights
By default, the `litgpt download` command converts the downloaded HF checkpoint files into a LitGPT compatible format after downloading. For example,
```bash
litgpt download EleutherAI/pythia-14m
```
creates the following files:
```
checkpoints/
└── EleutherAI/
└── pythia-14m/
├── config.json
├── generation_config.json
├── model_config.yaml # LitGPT specific file
├── lit_model.pth # LitGPT specific file
├── pytorch_model.bin
├── tokenizer.json
└── tokenizer_config.json
```
To disable the automatic conversion, which is useful for development and debugging purposes, you can run the `litgpt download` with the `--convert_checkpoint false` flag. This will only download the checkpoint files but do not convert them for use in LitGPT:
```bash
rm -rf checkpoints/EleutherAI/pythia-14m
litgpt download EleutherAI/pythia-14m \
--convert_checkpoint false
ls checkpoints/EleutherAI/pythia-14m
```
```
checkpoints/
└── EleutherAI/
└── pythia-14m/
├── config.json
├── generation_config.json
├── pytorch_model.bin
├── tokenizer.json
└── tokenizer_config.json
```
The required files `model_config.yaml` and `lit_model.pth` files can then be manually generated via the `litgpt/scripts/convert_hf_checkpoint.py` script:
```bash
litgpt convert_to_litgpt checkpoints/EleutherAI/pythia-14m
```
================================================
FILE: tutorials/convert_lit_models.md
================================================
## Converting LitGPT weights to Hugging Face Transformers
LitGPT weights need to be converted to a format that Hugging Face understands with a [conversion script](../litgpt/scripts/convert_lit_checkpoint.py) before our scripts can run.
We provide a helpful command to convert models LitGPT models back to their equivalent Hugging Face Transformers format:
```bash
litgpt convert_from_litgpt checkpoint_dir converted_dir
```
These paths are just placeholders, you will need to customize them based on which finetuning or pretraining command you ran and its configuration.
### Loading converted LitGPT checkpoints into transformers
For example,
```bash
cp checkpoints/repo_id/config.json converted/config.json
```
Then, you can load the checkpoint file in a Python session as follows:
```python
import torch
from transformers import AutoModel
state_dict = torch.load("output_dir/model.pth")
model = AutoModel.from_pretrained(
"output_dir/", local_files_only=True, state_dict=state_dict
)
```
Alternatively, you can also load the model without copying the `config.json` file as follows:
```python
model = AutoModel.from_pretrained("online_repo_id", state_dict=state_dict)
```
### Merging LoRA weights
Please note that if you want to convert a model that has been finetuned using an adapter like LoRA, these weights should be [merged](../litgpt/scripts/merge_lora.py) to the checkpoint prior to converting.
```sh
litgpt merge_lora path/to/lora/checkpoint_dir
```
# A finetuning and conversion tutorial
This section contains a reproducible example for finetuning a LitGPT model and converting it back into a HF `transformer` model.
1. Download a model of interest:
For convenience, we first specify an environment variable (optional) to avoid copy and pasting the whole path:
```bash
export repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
```
Instead of using TinyLlama, you can replace the `repo_id` target with any other model repository
specifier that is currently supported by LitGPT. You can get a list of supported repository specifier
by running `litgpt/scripts/download.py` without any additional arguments.
Then, we download the model we specified via `$repo_id` above:
```bash
litgpt download $repo_id
```
2. Finetune the model:
```bash
export finetuned_dir=out/lit-finetuned-model
litgpt finetune_lora $repo_id \
--out_dir $finetuned_dir \
--train.epochs 1 \
--data Alpaca
```
3. Merge LoRA weights:
Note that this step only applies if the model was finetuned with `lora.py` above and not when `full.py` was used for finetuning.
```bash
litgpt merge_lora $finetuned_dir/final
```
4. Convert the finetuning model back into a HF format:
```bash
litgpt convert_from_litgpt $finetuned_dir/final/ out/hf-tinyllama/converted
```
5. Load the model into a `transformers` model:
```python
import torch
from transformers import AutoModel
state_dict = torch.load('out/hf-tinyllama/converted/model.pth')
model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", state_dict=state_dict)
```
## Using the LM Evaluation Harness
To evaluate LitGPT models, use the integrated evaluation utilities based on Eleuther AI's LM Evaluation Harness. For more information, please see the [evaluation](evaluation.md) documentation.
Alternatively, if you wish to use converted LitGPT models with the LM Evaluation Harness from [Eleuther AI's GitHub repository](https://github.com/EleutherAI/lm-evaluation-harness), you can use the following steps.
1. Follow the instructions above to load the model into a Hugging Face transformers model.
2. Create a `model.safetensor` file:
```python
model.save_pretrained("out/hf-tinyllama/converted/")
```
3. Copy the tokenizer files into the model-containing directory:
```bash
cp checkpoints/$repo_id/tokenizer* out/hf-tinyllama/converted
```
4. Run the evaluation harness, for example:
```bash
lm_eval --model hf \
--model_args pretrained=out/hf-tinyllama/converted \
--tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \
--device "cuda:0" \
--batch_size 4
```
================================================
FILE: tutorials/deploy.md
================================================
# Serve and Deploy LLMs
This document shows how you can serve a LitGPT for deployment.
## Serve an LLM with LitServe
This section illustrates how we can set up an inference server for a phi-2 LLM using `litgpt serve` that is minimal and highly scalable.
### Step 1: Start the inference server
```bash
# 1) Download a pretrained model (alternatively, use your own finetuned model)
litgpt download microsoft/phi-2
# 2) Start the server
litgpt serve microsoft/phi-2
```
> [!TIP]
> Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more.
### Step 2: Query the inference server
You can now send requests to the inference server you started in step 2. For example, in a new Python session, we can send requests to the inference server as follows:
```python
import requests, json
response = requests.post(
"http://127.0.0.1:8000/predict",
json={"prompt": "Fix typos in the following sentence: Example input"}
)
print(response.json()["output"])
```
Executing the code above prints the following output:
```
Example input.
```
### Optional: Use the streaming mode
The 2-step procedure described above returns the complete response all at once. If you want to stream the response on a token-by-token basis, start the server with the streaming option enabled:
```bash
litgpt serve microsoft/phi-2 --stream true
```
Then, use the following updated code to query the inference server:
```python
import requests, json
response = requests.post(
"http://127.0.0.1:8000/predict",
json={"prompt": "Fix typos in the following sentence: Example input"},
stream=True
)
# stream the response
for line in response.iter_lines(decode_unicode=True):
if line:
print(json.loads(line)["output"], end="")
```
```
Sure, here is the corrected sentence:
Example input
```
## Serve an LLM with OpenAI-compatible API
LitGPT provides OpenAI-compatible endpoints that allow you to use the OpenAI SDK or any OpenAI-compatible client to interact with your models. This is useful for integrating LitGPT into existing applications that use the OpenAI API.
### Step 1: Start the server with OpenAI specification
```bash
# 1) Download a pretrained model (alternatively, use your own finetuned model)
litgpt download HuggingFaceTB/SmolLM2-135M-Instruct
# 2) Start the server with OpenAI-compatible endpoints
litgpt serve HuggingFaceTB/SmolLM2-135M-Instruct --openai_spec true
```
> [!TIP]
> The `--openai_spec true` flag enables OpenAI-compatible endpoints at `/v1/chat/completions` instead of the default `/predict` endpoint.
### Step 2: Query using OpenAI-compatible endpoints
You can now send requests to the OpenAI-compatible endpoint using curl:
```bash
curl -X POST http://127.0.0.1:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "SmolLM2-135M-Instruct",
"messages": [{"role": "user", "content": "Hello! How are you?"}]
}'
```
Or use the OpenAI Python SDK:
```python
from openai import OpenAI
# Configure the client to use your local LitGPT server
client = OpenAI(
base_url="http://127.0.0.1:8000/v1",
api_key="not-needed" # LitGPT doesn't require authentication by default
)
response = client.chat.completions.create(
model="SmolLM2-135M-Instruct",
messages=[
{"role": "user", "content": "Hello! How are you?"}
]
)
print(response.choices[0].message.content)
```
## Serve an LLM UI with Chainlit
If you are interested in developing a simple ChatGPT-like UI prototype, see the Chainlit tutorial in the following Studio:
================================================
FILE: tutorials/developer-docs/README.md
================================================
LitGPT developer documentation files.
================================================
FILE: tutorials/developer-docs/adding-models.md
================================================
# Adding New Models
This document provides an overview and explanation of how new LLM architectures and model weights can be added to LitGPT.
> [!NOTE]
> One of the design focus areas of LitGPT is to provide efficient readable code. At the same time, LitGPT aims to support selected LLMs that are useful to the community. LitGPT aims to reuse and share as much code as possible between different LLMs to strike a balance between code readability and enabling support for various LLMs. In short, we try to minimize writing custom code for a given LLM and aim for code reuse.
## 1. Discuss the LLM to be added
As an open-source project, we appreciate your contributions! However, before you begin putting valuable time and work into a contribution, ideally, open an issue to discuss whether support for a certain model is within the project's scope.
## 2. Set up your development environment
Clone the repository:
```bash
git clone https://github.com/Lightning-AI/litgpt.git
```
Then, install it with the "editable" mode for development:
```bash
cd litgpt
pip install litgpt -e ".[all]"
```
## 3. Update the config file
Update the [litgpt/config.py](../../litgpt/config.py) config file, adding the new model configuration there. It's easiest to start with the most similar model, copy the configuration, and then modify it according to the `config.json` file on the HF hub.
For example, suppose an entry for Llama 3 8B already exists and you want to add support for Llama 3 70B.
Copy the Llama 3 8B entry:
```python
# https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json
dict(
name="Llama-3-8B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"),
vocab_size=128256,
padding_multiple=64,
n_layer=32,
n_head=32,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
),
```
Then create the entry for the 70B model. Here, make sure you update the values according to the `config.json` file available on the HF hub:
```python
# https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
dict(
name="Llama-3-70B{}",
hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"),
vocab_size=128256,
padding_multiple=64,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
),
```
> [!NOTE]
> Some models may require you to implement a new MLP class analogous to `class LLaMAMLP`.
> A more or less reliable indicator is the presence of a `modeling.py` file in the model's original repository.
> If this file exists, it suggests that this model requires custom code.
> This will then also require additional changes beyond simply updating
> the configuration in LitGPT's `config.py`.
## 4. Try downloading the model
After making the modifications above, try downloading the model:
```bash
litgpt download meta-llama/Meta-Llama-3-70B --access_token ...
```
> [!NOTE]
> Not all models require an access token
If the conversion following the download fails, proceed with the next section.
## 5. Update the checkpoint conversion script
If the `litgpt download ...` command from the previous section failed, you may have to adjust the checkpoint conversion script: [litgpt/scripts/convert_hf_checkpoint.py](../../litgpt/scripts/convert_hf_checkpoint.py).
Here, you may have to adjust or implement a new `def copy_weights_hf_...` function.
You can test the updated conversion code without needing to redownload the weights as follows:
```bash
python litgpt/scripts/convert_hf_checkpoint.py meta-llama/Meta-Llama-3-70B
```
## 6. Add the Prompt Style
If you are adding a new model class, find out its prompt style. First, check [litgpt/prompts.py](../../litgpt/prompts.py) if a similar prompt style template already exists. For Llama 3, this is as follows:
```python
class Llama3(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
# https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229
return (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful assistant.<|eot_id|>\n" # The system prompt is optional
"<|start_header_id|>user<|end_header_id|>\n\n"
f"{prompt}<|eot_id|>\n"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|eot_id|>")],
)
```
If your model requires a different prompt template, create a new `PromptStyle` class.
Then, in the same file, update the `prompt_styles` dictionary:
```python
prompt_styles: Dict[str, Type[PromptStyle]] = {
...
"llama3": Llama3,
}
```
Finally, also in the same file, update the `model_name_to_prompt_style` function:
```python
def model_name_to_prompt_style(model_name: str) -> PromptStyle:
...
if re.search("Llama-3.*-Instruct", model_name):
return Llama3()
```
## 7. Try using the model for inference
Next, use the model to see if inference works:
```bash
litgpt generate meta-llama/Meta-Llama-3-70B
```
> [!NOTE]
> If you notice that the model produces non-sensible language outputs, you need to double-check the config file and find out if there are incorrect values or other problems. The next section on adding unit tests may offer additional pointers.
## 8. Add unit tests
### 8.1 Add model unit tests
Open the [`tests/test_model.py`](../../tests/test_model.py) file and add a new `def test_against_hf_...` function using one of the existing functions as a template. For instance,
```python
def test_against_hf_llama2(ours_kwargs, device, dtype):
...
# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
```
If the
```bash
litgpt generate meta-llama/Meta-Llama-3-70B
```
command from the previous section produces incoherent text, this function can be a helpful guide for debugging. For this, modify the implementation in `transformers` and `litgpt` packages (on your local installation), to inspect or print out the intermediate values at a layer. It's recommend starting with the embedding layers and then go through one layer at the time, to find out where the values differ to get pointers for debugging.
Test the unit test via
```python
pytest tests/test_model.py::test_against_hf_...
```
### 8.2 Add prompt style unit test
Open the [`tests/test_model.py`](../../tests/test_model.py) file and add a test for the respective prompts you added earlier, if applicable. For example,
```python
def test_prompt_style_from_config():
model_names = [
...
"Llama-3-70B-Instruct",
...
]
```
Run the unit test via
```python
pytest tests/test_prompts.py
```
## 9. Try finetuning the model
Now, try finetuning the model:
```bash
litgpt finetune meta-llama/Meta-Llama-3-70B --train.max_steps 10
```
## 10. Update the documentation
Finally, update the documentation files.
### 10.1 Update the README file
Update the "All Models" table in the [README.md](../../README.md) file.
### 10.2 Update the download tutorials
Add the new model to the model table at the top as well as to the list under `litgpt download list`.
================================================
FILE: tutorials/developer-docs/python-api.md
================================================
# LitGPT High-level Python API
This is a work-in-progress draft for a high-level LitGPT Python API.
## Model loading & saving
The `LLM.load` command loads an `llm` object, which contains both the model object (a PyTorch module) and a preprocessor.
```python
from litgpt import LLM
llm = LLM.load(
model="url | local_path",
# high-level user only needs to care about those:
memory_reduction="none | medium | strong"
# advanced options for technical users:
source="hf | local | other"
quantize="bnb.nf4",
precision="bf16-true",
device=""auto | cuda | cpu",
)
```
Here,
- `llm.model` contains the PyTorch Module
- and `llm.preprocessor.tokenizer` contains the tokenizer
The `llm.save` command saves the model weights, tokenizer, and configuration information.
```python
llm.save(checkpoint_dir, format="lightning | ollama | hf")
```
## Inference / Chat
```
response = llm.generate(
prompt="What do Llamas eat?",
temperature=0.1,
top_p=0.8,
...
)
```
## Dataset
The `llm.prepare_dataset` command prepares a dataset for training.
```
llm.download_dataset(
URL,
...
)
```
```
dataset = llm.prepare_dataset(
path,
task="pretrain | instruction_finetune",
test_portion=0.1,
...
)
```
## Training
```python
llm.instruction_finetune(
config=None,
dataset=dataset,
max_iter=10,
method="full | lora | adapter | adapter_v2"
)
```
```python
llm.pretrain(config=None, dataset=dataset, max_iter=10, ...)
```
## Serving
```python
llm.serve(port=8000)
```
Then in another Python session:
```python
import requests, json
response = requests.post(
"http://127.0.0.1:8000/predict",
json={"prompt": "Fix typos in the following sentence: Example input"}
)
print(response.json()["output"])
```
================================================
FILE: tutorials/download_model_weights.md
================================================
# Download Model Weights with LitGPT
LitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the `litgpt download list` command.
| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) |
| Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786)
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) |
| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) |
| Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219)
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743) |
| Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233) |
| Phi 4 Reasoning | 3.8B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) |
| Phi 4 Reasoning Plus | 3.8B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/) |
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
| QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
| Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
| R1 Distll Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | |
## General Instructions
### 1. List Available Models
To see all supported models, run the following command:
```bash
litgpt download list
```
The output is shown below:
```
allenai/OLMo-1B-hf
allenai/OLMo-7B-hf
allenai/OLMo-7B-Instruct-hf
bsc-lt/salamandra-2b
bsc-lt/salamandra-2b-instruct
bsc-lt/salamandra-7b
bsc-lt/salamandra-7b-instruct
codellama/CodeLlama-13b-hf
codellama/CodeLlama-13b-Instruct-hf
codellama/CodeLlama-13b-Python-hf
codellama/CodeLlama-34b-hf
codellama/CodeLlama-34b-Instruct-hf
codellama/CodeLlama-34b-Python-hf
codellama/CodeLlama-70b-hf
codellama/CodeLlama-70b-Instruct-hf
codellama/CodeLlama-70b-Python-hf
codellama/CodeLlama-7b-hf
codellama/CodeLlama-7b-Instruct-hf
codellama/CodeLlama-7b-Python-hf
databricks/dolly-v2-12b
databricks/dolly-v2-3b
databricks/dolly-v2-7b
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
deepseek-ai/DeepSeek-R1-Distill-Llama-70B
EleutherAI/pythia-1.4b
EleutherAI/pythia-1.4b-deduped
EleutherAI/pythia-12b
EleutherAI/pythia-12b-deduped
EleutherAI/pythia-14m
EleutherAI/pythia-160m
EleutherAI/pythia-160m-deduped
EleutherAI/pythia-1b
EleutherAI/pythia-1b-deduped
EleutherAI/pythia-2.8b
EleutherAI/pythia-2.8b-deduped
EleutherAI/pythia-31m
EleutherAI/pythia-410m
EleutherAI/pythia-410m-deduped
EleutherAI/pythia-6.9b
EleutherAI/pythia-6.9b-deduped
EleutherAI/pythia-70m
EleutherAI/pythia-70m-deduped
garage-bAInd/Camel-Platypus2-13B
garage-bAInd/Camel-Platypus2-70B
garage-bAInd/Platypus-30B
garage-bAInd/Platypus2-13B
garage-bAInd/Platypus2-70B
garage-bAInd/Platypus2-70B-instruct
garage-bAInd/Platypus2-7B
garage-bAInd/Stable-Platypus2-13B
google/codegemma-7b-it
google/gemma-3-27b-it
google/gemma-3-12b-it
google/gemma-3-4b-it
google/gemma-3-1b-it
google/gemma-2-27b
google/gemma-2-27b-it
google/gemma-2-2b
google/gemma-2-2b-it
google/gemma-2-9b
google/gemma-2-9b-it
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
h2oai/h2o-danube2-1.8b-chat
HuggingFaceTB/SmolLM2-135M
HuggingFaceTB/SmolLM2-135M-Instruct
HuggingFaceTB/SmolLM2-360M
HuggingFaceTB/SmolLM2-360M-Instruct
HuggingFaceTB/SmolLM2-1.7B
HuggingFaceTB/SmolLM2-1.7B-Instruct
lmsys/longchat-13b-16k
lmsys/longchat-7b-16k
lmsys/vicuna-13b-v1.3
lmsys/vicuna-13b-v1.5
lmsys/vicuna-13b-v1.5-16k
lmsys/vicuna-33b-v1.3
lmsys/vicuna-7b-v1.3
lmsys/vicuna-7b-v1.5
lmsys/vicuna-7b-v1.5-16k
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-3.2-1B
meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.2-3B
meta-llama/Llama-3.2-3B-Instruct
meta-llama/Llama-3.3-70B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3.1-405B
meta-llama/Meta-Llama-3.1-405B-Instruct
meta-llama/Meta-Llama-3.1-70B
meta-llama/Meta-Llama-3.1-70B-Instruct
meta-llama/Meta-Llama-3.1-8B
meta-llama/Meta-Llama-3.1-8B-Instruct
microsoft/phi-1_5
microsoft/phi-2
microsoft/Phi-3-mini-128k-instruct
microsoft/Phi-3-mini-4k-instruct
microsoft/Phi-3.5-mini-instruct
microsoft/phi-4
microsoft/Phi-4-mini-instruct
mistralai/mathstral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-Instruct-v0.3
mistralai/Mistral-7B-v0.1
mistralai/Mistral-7B-v0.3
mistralai/Mistral-Large-Instruct-2407
mistralai/Mistral-Large-Instruct-2411
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x22B-Instruct-v0.1
mistralai/Mixtral-8x22B-v0.1
NousResearch/Nous-Hermes-13b
NousResearch/Nous-Hermes-llama-2-7b
NousResearch/Nous-Hermes-Llama2-13b
nvidia/Llama-3.1-Nemotron-70B-Instruct-HF
openlm-research/open_llama_13b
openlm-research/open_llama_3b
openlm-research/open_llama_7b
Qwen/Qwen2.5-0.5B
Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-3B
Qwen/Qwen2.5-3B-Instruct
Qwen/Qwen2.5-7B
Qwen/Qwen2.5-7B-Instruct
Qwen/Qwen2.5-7B-Instruct-1M
Qwen/Qwen2.5-14B
Qwen/Qwen2.5-14B-Instruct
Qwen/Qwen2.5-14B-Instruct-1M
Qwen/Qwen2.5-32B
Qwen/Qwen2.5-32B-Instruct
Qwen/Qwen2.5-72B
Qwen/Qwen2.5-72B-Instruct
Qwen/Qwen2.5-Coder-0.5B
Qwen/Qwen2.5-Coder-0.5B-Instruct
Qwen/Qwen2.5-Coder-1.5B
Qwen/Qwen2.5-Coder-1.5B-Instruct
Qwen/Qwen2.5-Coder-3B
Qwen/Qwen2.5-Coder-3B-Instruct
Qwen/Qwen2.5-Coder-7B
Qwen/Qwen2.5-Coder-7B-Instruct
Qwen/Qwen2.5-Coder-14B
Qwen/Qwen2.5-Coder-14B-Instruct
Qwen/Qwen2.5-Coder-32B
Qwen/Qwen2.5-Coder-32B-Instruct
Qwen/Qwen2.5-Math-1.5B
Qwen/Qwen2.5-Math-1.5B-Instruct
Qwen/Qwen2.5-Math-7B
Qwen/Qwen2.5-Math-7B-Instruct
Qwen/Qwen2.5-Math-72B
Qwen/Qwen2.5-Math-72B-Instruct
Qwen/Qwen3-0.6B
Qwen/Qwen3-0.6B-Base
Qwen/Qwen3-1.7B
Qwen/Qwen3-1.7B-Base
Qwen/Qwen3-4B
Qwen/Qwen3-4B-Base
Qwen/Qwen3-8B
Qwen/Qwen3-8B-Base
Qwen/Qwen3-14B
Qwen/Qwen3-14B-Base
Qwen/Qwen3-32B
Qwen/Qwen3-30B-A3B
Qwen/Qwen3-30B-A3B-Base
Qwen/Qwen3-235B-A22B
Qwen/Qwen3-4B-Thinking-2507
Qwen/Qwen3-4B-Instruct-2507
Qwen/Qwen3-30B-A3B-Thinking-2507
Qwen/Qwen3-30B-A3B-Instruct-2507
Qwen/Qwen3-235B-A22B-Thinking-2507
Qwen/Qwen3-235B-A22B-Instruct-2507
Qwen/QwQ-32B
Qwen/QwQ-32B-Preview
stabilityai/FreeWilly2
stabilityai/stable-code-3b
stabilityai/stablecode-completion-alpha-3b
stabilityai/stablecode-completion-alpha-3b-4k
stabilityai/stablecode-instruct-alpha-3b
stabilityai/stablelm-3b-4e1t
stabilityai/stablelm-base-alpha-3b
stabilityai/stablelm-base-alpha-7b
stabilityai/stablelm-tuned-alpha-3b
stabilityai/stablelm-tuned-alpha-7b
stabilityai/stablelm-zephyr-3b
tiiuae/falcon-180B
tiiuae/falcon-180B-chat
tiiuae/falcon-40b
tiiuae/falcon-40b-instruct
tiiuae/falcon-7b
tiiuae/falcon-7b-instruct
tiiuae/Falcon3-1B-Base
tiiuae/Falcon3-1B-Instruct
tiiuae/Falcon3-3B-Base
tiiuae/Falcon3-3B-Instruct
tiiuae/Falcon3-7B-Base
tiiuae/Falcon3-7B-Instruct
tiiuae/Falcon3-10B-Base
tiiuae/Falcon3-10B-Instruct
TinyLlama/TinyLlama-1.1B-Chat-v1.0
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
togethercomputer/LLaMA-2-7B-32K
togethercomputer/RedPajama-INCITE-7B-Base
togethercomputer/RedPajama-INCITE-7B-Chat
togethercomputer/RedPajama-INCITE-7B-Instruct
togethercomputer/RedPajama-INCITE-Base-3B-v1
togethercomputer/RedPajama-INCITE-Base-7B-v0.1
togethercomputer/RedPajama-INCITE-Chat-3B-v1
togethercomputer/RedPajama-INCITE-Chat-7B-v0.1
togethercomputer/RedPajama-INCITE-Instruct-3B-v1
togethercomputer/RedPajama-INCITE-Instruct-7B-v0.1
Trelis/Llama-2-7b-chat-hf-function-calling-v2
unsloth/Mistral-7B-v0.2
```
> [!TIP]
> To sort the list above by model name after the `/`, use `litgpt download list | sort -f -t'/' -k2`.
> [!NOTE]
> If you want to adopt a model variant that is not listed in the table above but has a similar architecture as one of the supported models, you can use this model by by using the `--model_name` argument as shown below:
>
> ```bash
> litgpt download NousResearch/Hermes-2-Pro-Mistral-7B \
> --model_name Mistral-7B-v0.1
> ```
### 2. Download Model Weights
To download the weights for a specific model provide a `` with the model's repository ID. For example:
```bash
litgpt download
```
This command downloads the model checkpoint into the `checkpoints/` directory.
### 3. Additional Help
For more options, add the `--help` flag when running the script:
```bash
litgpt download --help
```
### 4. Run the Model
After conversion, run the model with the given checkpoint path as input, adjusting `repo_id` accordingly:
```bash
litgpt chat
```
## Tinyllama Example
This section shows a typical end-to-end example for downloading and using TinyLlama:
1. List available TinyLlama checkpoints:
```bash
litgpt download list | grep Tiny
```
```
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
TinyLlama/TinyLlama-1.1B-Chat-v1.0
```
2. Download a TinyLlama checkpoint:
```bash
export repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
litgpt download $repo_id
```
3. Use the TinyLlama model:
```bash
litgpt chat $repo_id
```
## Specific models and access tokens
Note that certain models require that you've been granted access to the weights on the Hugging Face Hub.
For example, to get access to the Gemma 2B model, you can do so by following the steps at . After access is granted, you can find your HF hub token in .
Once you've been granted access and obtained the access token you need to pass the additional `--access_token`:
```bash
litgpt download google/gemma-2b \
--access_token your_hf_token
```
## Finetunes and Other Model Variants
Sometimes you want to download the weights of a finetune of one of the models listed above. To do this, you need to manually specify the `model_name` associated to the config to use. For example:
```bash
litgpt download NousResearch/Hermes-2-Pro-Mistral-7B \
--model_name Mistral-7B-v0.1
```
## Tips for GPU Memory Limitations
The `litgpt download` command will automatically convert the downloaded model checkpoint into a LitGPT-compatible format. In case this conversion fails due to GPU memory constraints, you can try to reduce the memory requirements by passing the `--dtype bf16-true` flag to convert all parameters into this smaller precision (however, note that most model weights are already in a bfloat16 format, so it may not have any effect):
```bash
litgpt download
--dtype bf16-true
```
(If your GPU does not support the bfloat16 format, you can also try a regular 16-bit float format via `--dtype 16-true`.)
## Converting Checkpoints Manually
For development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps.
You can do this by passing the `--convert_checkpoint false` option to the download script:
```bash
litgpt download \
--convert_checkpoint false
```
and then calling the `convert_hf_checkpoint` command:
```bash
litgpt convert_to_litgpt
```
## Downloading Tokenizers Only
In some cases we don't need the model weight, for example, when we are pretraining a model from scratch instead of finetuning it. For cases like this, you can use the `--tokenizer_only` flag to only download a model's tokenizer, which can then be used in the pretraining scripts:
```bash
litgpt download TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
--tokenizer_only true
```
and
```bash
litgpt pretrain tiny-llama-1.1b \
--data ... \
--tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/
```
================================================
FILE: tutorials/evaluation.md
================================================
# LLM Evaluation
## Using lm-evaluation-harness
You can evaluate LitGPT using [EleutherAI's lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) framework with a large number of different evaluation tasks.
You need to install the `lm-eval` framework first:
```bash
pip install lm_eval
```
### Evaluating LitGPT base models
Suppose you downloaded a base model that we want to evaluate. Here, we use the `microsoft/phi-2` model:
```bash
litgpt download microsoft/phi-2
```
The download command above will save the model to the `checkpoints/microsoft/phi-2` directory, which we can
specify in the following evaluation command:
```
litgpt evaluate microsoft/phi-2/ \
--batch_size 4 \
--tasks "hellaswag,truthfulqa_mc2,mmlu" \
--out_dir evaluate_model/
```
The resulting output is as follows:
```
...
|---------------------------------------|-------|------|-----:|--------|-----:|---|-----:|
...
|truthfulqa_mc2 | 2|none | 0|acc |0.4656|± |0.0164|
|hellaswag | 1|none | 0|acc |0.2569|± |0.0044|
| | |none | 0|acc_norm|0.2632|± |0.0044|
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.2434|± |0.0036|
| - humanities |N/A |none | 0|acc |0.2578|± |0.0064|
| - other |N/A |none | 0|acc |0.2401|± |0.0077|
| - social_sciences|N/A |none | 0|acc |0.2301|± |0.0076|
| - stem |N/A |none | 0|acc |0.2382|± |0.0076|
```
Please note that the `litgpt evaluate` command run an internal model conversion.
This is only necessary the first time you want to evaluate a model, and it will skip the
conversion steps if you run the `litgpt evaluate` on the same checkpoint directory again.
In some cases, for example, if you modified the model in the `checkpoint_dir` since the first `litgpt evaluate`
call, you need to use the `--force_conversion` flag to to update the files used by litgpt evaluate accordingly:
```
litgpt evaluate microsoft/phi-2/ \
--batch_size 4 \
--out_dir evaluate_model/ \
--tasks "hellaswag,truthfulqa_mc2,mmlu" \
--force_conversion true
```
> [!TIP]
> Run `litgpt evaluate list` to print a list
> of the supported tasks. To filter for a specific subset of tasks, e.g., MMLU, use `litgpt evaluate list | grep mmlu`.
> [!TIP]
> The evaluation may take a long time, and for testing purpoes, you may want to reduce the number of tasks
> or set a limit for the number of examples per task, for example, `--limit 10`.
### Evaluating LoRA-finetuned LLMs
No further conversion is necessary when evaluating LoRA-finetuned models as the `finetune_lora` command already prepares the necessary merged model files:
```bash
litgpt finetune_lora microsoft/phi-2 \
--out_dir lora_model
```
```bash
litgpt evaluate lora_model/final \
--batch_size 4 \
--tasks "hellaswag,truthfulqa_mc2,mmlu" \
--out_dir evaluate_model/ \
```
### Evaluating on a custom test set
There is currently no built-in function to evaluate models on custom test sets. However, this section describes a general approach that users can take to evaluate the responses of a model using another LLM.
Suppose you have a test dataset with the following structure:
```python
test_data = [
{
"instruction": "Name the author of 'Pride and Prejudice'.",
"input": "",
"output": "Jane Austen."
},
{
"instruction": "Pick out the adjective from the following list.",
"input": "run, tall, quickly",
"output": "The correct adjective from the list is 'tall.'"
},
]
```
For simplicity, the dictionary above only contains two entries. In practice, it is recommended to use test datasets that contain at least 100 entries (ideally 1000 or more).
If your dataset is stored in JSON format, use the following code to load it:
```python
with open("test_data.json", "r") as file:
test_data = json.load(file)
```
Next, it is recommended to format the dataset according to a prompt style. For example, to use the `Alpaca` prompt style, use the following code:
```python
from litgpt.prompts import Alpaca
prompt_style = Alpaca()
prompt_style.apply(prompt=test_data[0]["instruction"], **test_data[0])
```
which returns
```
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nName the author of 'Pride and Prejudice'.\n\n### Response:\n
```
Next, load the LLM you want to evaluate. For this example, we use `phi-2`:
```python
from litgpt import LLM
llm = LLM.load("microsoft/phi-2")
```
Then, using the loaded model, we add the test set responses to the dataset:
```python
from tqdm import trange
for i in trange(len(test_data)):
response = llm.generate(prompt_style.apply(prompt=test_data[i]["instruction"], **test_data[i]))
test_data[i]["response"] = response
```
Next, we use a second LLM to calculate the response quality on a scale from 0 to 100. It is recommended to use the 70B Llama 3 instruction-fintuned model for this task, or the smaller 8B Llama 3 model, which is more resource-efficient:
```python
del llm # delete previous `llm` to free up GPU memory
scorer = LLM.load("meta-llama/Meta-Llama-3-8B-Instruct", access_token="...")
```
Then, based on this LLM, we calculate the response quality with the following function:
```python
from tqdm import tqdm
def generate_model_scores(data_dict, model, response_field="response", target_field="output"):
scores = []
for entry in tqdm(data_dict, desc="Scoring entries"):
prompt = (
f"Given the input `{format_input(entry)}` "
f"and correct output `{entry[target_field]}`, "
f"score the model response `{entry[response_field]}`"
f" on a scale from 0 to 100, where 100 is the best score. "
f"Respond with the integer number only."
)
score = model.generate(prompt, max_new_tokens=50)
try:
scores.append(int(score))
except ValueError:
continue
return scores
```
```python
scores = generate_model_scores(test_data, model=scorer)
print(f"\n{llm}")
print(f"Number of scores: {len(scores)} of {len(test_data)}")
print(f"Average score: {sum(scores)/len(scores):.2f}\n")
```
This will print out the average score on all test set entries:
```
Scoring entries: 100%|██████████| 2/2 [00:00<00:00, 4.37it/s]
Number of scores: 2 of 2
Average score: 47.50
```
================================================
FILE: tutorials/examples/ptl-trainer/README.md
================================================
## Minimal PyTorch Lightning Trainer Example
The script in this folder provides minimal examples showing how to train a LitGPT model using LitGPT's `GPT` class with the [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) Trainer.
You can run the scripts as follows:
 
## Small 160M model:
```bash
# Download the Pythia model
litgpt download EleutherAI/pythia-160m
python litgpt_ptl_small.py
```
 
## Medium-sized 8B model:
```bash
# Download the Llama 3.1 model
litgpt download meta-llama/Meta-Llama-3.1-8B --access_token hf_...
python litgpt_ptl_medium.py
```
================================================
FILE: tutorials/examples/ptl-trainer/litgpt_ptl_medium.py
================================================
import lightning as L
import torch
import litgpt
from litgpt.data import Alpaca2k
from litgpt.lora import GPT, merge_lora_weights
class LitLLM(L.LightningModule):
def __init__(self):
super().__init__()
self.model = GPT.from_name(
name="Llama-3.1-8B",
lora_r=32,
lora_alpha=16,
lora_dropout=0.05,
lora_key=False,
lora_value=True,
)
litgpt.lora.mark_only_lora_as_trainable(self.model)
def on_train_start(self):
state_dict = torch.load("checkpoints/meta-llama/Meta-Llama-3.1-8B/lit_model.pth", mmap=True)
self.model.load_state_dict(state_dict, strict=False)
def training_step(self, batch):
input_ids, targets = batch["input_ids"], batch["labels"]
logits = self.model(input_ids)
loss = litgpt.utils.chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
warmup_steps = 10
optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
return [optimizer], [scheduler]
if __name__ == "__main__":
data = Alpaca2k()
tokenizer = litgpt.Tokenizer("checkpoints/meta-llama/Meta-Llama-3.1-8B")
data.connect(tokenizer, batch_size=1, max_seq_length=512)
trainer = L.Trainer(
devices=1,
max_epochs=2,
accumulate_grad_batches=8,
precision="bf16-true",
)
with trainer.init_module(empty_init=True):
model = LitLLM()
trainer.fit(model, data)
# Save final checkpoint
merge_lora_weights(model.model)
trainer.save_checkpoint("checkpoints/finetuned.ckpt", weights_only=True)
================================================
FILE: tutorials/examples/ptl-trainer/litgpt_ptl_small.py
================================================
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import lightning as L
import torch
from litgpt import LLM
from litgpt.data import Alpaca2k
class LitLLM(L.LightningModule):
def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):
super().__init__()
self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)
self.trainer_ckpt_path = trainer_ckpt_path
def setup(self, stage):
self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)
def training_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("train_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("validation_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
warmup_steps = 10
optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
return [optimizer], [scheduler]
if __name__ == "__main__":
batch_size = 8
accumulate_grad_batches = 1
#########################################################
# Use case 1: Pretraining from random weights
#########################################################
llm = LLM.load("EleutherAI/pythia-160m", tokenizer_dir="EleutherAI/pythia-160m", init="random")
llm.save("pythia-160m-random-weights")
del llm
lit_model = LitLLM(checkpoint_dir="pythia-160m-random-weights", tokenizer_dir="EleutherAI/pythia-160m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
del lit_model
#############################################################################
# Use case 2: Continued pretraining / finetuning from downloaded checkpoint
#############################################################################
lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
del lit_model
#########################################################
# Use case 3: Resume training from Trainer checkpoint
#########################################################
import os
def find_latest_checkpoint(directory):
latest_checkpoint = None
latest_time = 0
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".ckpt"):
file_path = os.path.join(root, file)
file_time = os.path.getmtime(file_path)
if file_time > latest_time:
latest_time = file_time
latest_checkpoint = file_path
return latest_checkpoint
lit_model = LitLLM(
checkpoint_dir="EleutherAI/pythia-160m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs")
)
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
#################################################################
# Use case 4: Resume training after saving a checkpoint manually
#################################################################
lit_model.llm.save("finetuned_checkpoint")
del lit_model
lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint")
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
================================================
FILE: tutorials/finetune.md
================================================
# Finetuning
We provide a simple finetuning commands (`litgpt finetune_*`) that instruction-finetune a pretrained model on datasets such as [Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [Dolly](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm), and others. For more information on the supported instruction datasets and how to prepare your own custom datasets, please see the [tutorials/prepare_dataset](prepare_dataset.md) tutorials.
LitGPT currently supports the following finetuning methods:
```bash
litgpt finetune_full
litgpt finetune_lora
litgpt finetune_adapter
litgpt finetune_adapter_v2
```
> [!TIP]
> To install all required dependencies before finetuning, first run `pip install "litgpt[all]"`.
The following section provides more details about these methods, including links for additional resources.
## LitGPT finetuning commands
The section below provides additional information on the available and links to further resources.
### Full finetuning
```bash
litgpt finetune_full
```
This method trains all model weight parameters and is the most memory-intensive finetuning technique in LitGPT.
**More information and resources:**
- the LitGPT [tutorials/finetune_full](finetune_full.md) tutorial
### LoRA and QLoRA finetuning
```bash
litgpt finetune_lora stabilityai/stablelm-base-alpha-3b
```
LoRA and QLoRA are parameter-efficient finetuning technique that only require updating a small number of parameters, which makes this a more memory-efficienty alternative to full finetuning.
**More information and resources:**
- the LitGPT [tutorials/finetune_lora](finetune_lora.md) tutorial
- the LoRA paper by ([Hu et al. 2021](https://arxiv.org/abs/2106.09685))
- the conceptual tutorial [Parameter-Efficient LLM Finetuning With Low-Rank Adaptation (LoRA)](https://lightning.ai/pages/community/tutorial/lora-llm/)
### Adapter finetuning
```bash
litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b
```
or
```bash
litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b
```
Similar to LoRA, adapter finetuning is a parameter-efficient finetuning technique that only requires training a small subset of weight parameters, making this finetuning method more memory-efficient than full-parameter finetuning.
**More information and resources:**
- the LitGPT [tutorials/finetune_adapter](finetune_adapter.md) tutorial
- the Llama-Adapter ([Gao et al. 2023](https://arxiv.org/abs/2304.15010)) and Llama-Adapter v2 ([Zhang et al. 2023](https://arxiv.org/abs/2303.16199)) papers that originally introduces these methods
- the conceptual tutorial [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/)
================================================
FILE: tutorials/finetune_adapter.md
================================================
# Finetuning with Adapter
Adapter, first introduced for the LLaMA model as [LLaMA-Adapter](https://arxiv.org/abs/2303.16199), is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in an LLM. In total, there are only ~500k parameters to update during finetuning in StableLM 3B, which significantly reduces the memory footprint and speeds up training.
We are able to demonstrate instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3060 GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
If you are new to Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.
LLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the StableLM 3B model has 500k trainable parameters with GPT v1, GPT-Adapter v2 adds an additional 1.5 M trainable parameter for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, adapter v2 has ~2.3 M trainable parameters in total.
## Preparation
The steps here only need to be done once:
1. Follow the instructions in the [README](../README.md) to install the dependencies.
2. Download and convert the weights following our [guide](download_model_weights.md).
LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.
You can optionally [prepare your own dataset](#tune-on-your-dataset).
For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.
## Running the finetuning
```bash
litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \
--data Alpaca \
```
or for Adapter V2
```bash
litgpt finetune adapter_v2 stabilityai/stablelm-base-alpha-3b \
--data Alpaca \
```
The finetuning requires at least one GPU with ~12 GB memory.
You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
To fit Adapter V2 to 12GB memory set `--train.micro_batch_size 2`.
For example, the following settings will let you finetune the model in under 1 hour:
```bash
--devices 4 --train.micro_batch_size 4
```
This script will save checkpoints periodically to the `out_dir` directory. If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name:
```bash
litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \
--data Alpaca \
--out_dir out/adapter/my-model-finetuned
```
or for Adapter V2
```bash
litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \
--data Alpaca \
--out_dir out/adapter_v2/my-model-finetuned
```
If your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument.
For instance, to fine-tune on MPS (the GPU on modern Macs), you can run
```bash
litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \
--data Alpaca \
--out_dir out/adapter/my-model-finetuned \
--precision 32-true
```
Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac.
### Quantization
Optionally, finetuning using quantization can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:
```bash
litgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \
--quantize "bnb.nf4"
```
or using `adapter_v2` with double-quantization:
```bash
litgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \
--quantize "bnb.nf4-dq"
```
For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).
## Test the model
You can test the finetuned model with your own instructions by running:
```bash
litgpt generate_adapter stabilityai/stablelm-base-alpha-3b \
--prompt "Recommend a movie to watch on the weekend."
```
or for Adapter V2
```bash
litgpt generate_adapter_v2 stabilityai/stablelm-base-alpha-3b \
--prompt "Recommend a movie to watch on the weekend."
```
Output:
```text
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
```
If your GPU supports `bfloat16`, the script will automatically use it.
## Tune on your dataset
You can easily train on your own instruction dataset saved in JSON format.
1. Create a JSON file in which each row holds one instruction-response pair.
A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.
Below is an example json file:
```text
[
{
"instruction": "Arrange the given numbers in ascending order.",
"input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template
"output": "0, 2, 3, 4, 8"
},
...
]
```
2. Run `litgpt adapter` or `litgpt adapter_v2` by passing in the location of your data (and optionally other parameters):
```bash
litgpt finetune_adapter tiiuae/falcon-7b \
--data JSON \
--data.json_path data/mydata.json \
--out_dir data/mydata-finetuned
```
================================================
FILE: tutorials/finetune_full.md
================================================
# Finetuning the whole model
If you are interested in parameter-efficient finetuning, check out [finetune_adapter.md](finetune_adapter.md). In contrast to parameter-efficient finetuning, this "full" approach finetunes all model parameters, which is substantially more expensive. It may only be recommended as a baseline for comparison studies.
## Preparation
The steps here only need to be done once:
1. Follow the instructions in the [README](../README.md) to install the dependencies.
2. Download and convert the weights following our [guide](download_model_weights.md).
LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.
You can optionally [prepare your own dataset](#tune-on-your-dataset).
For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.
## Running the finetuning
```bash
litgpt finetune_full tiiuae/falcon-7b \
--data Alpaca \
```
Finetuning the falcon-7b model requires at least 8 GPUs with ~40 GB memory each.
You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
This script will save checkpoints periodically to the `out_dir` directory. If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name:
```bash
litgpt finetune_full tiiuae/falcon-7b \
--data Alpaca \
--out_dir out/full/my-model-finetuned
```
If your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument.
For instance, to fine-tune on MPS (the GPU on modern Macs), you can run
```bash
litgpt finetune_full tiiuae/falcon-7b \
--data Alpaca \
--out_dir out/full/my-model-finetuned \
--precision 32-true
```
Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac.
## Test the model
You can test the finetuned model with your own instructions by running:
```bash
litgpt generate tiiuae/falcon-7b \
--prompt "Recommend a movie to watch on the weekend." \
--finetuned_path out/full/my-model-finetuned/lit_model_finetuned.pth
```
Output:
```text
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
```
If your GPU supports `bfloat16`, the script will automatically use it.
## Tune on your dataset
You can easily train on your own instruction dataset saved in JSON format.
1. Create a JSON file in which each row holds one instruction-response pair.
A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.
Below is an example json file:
```text
[
{
"instruction": "Arrange the given numbers in ascending order.",
"input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template
"output": "0, 2, 3, 4, 8"
},
...
]
```
2. Run `litgpt finetune` by passing in the location of your data (and optionally other parameters):
```bash
litgpt finetune tiiuae/falcon-7b \
--data JSON \
--data.json_path data/mydata.json \
--out_dir data/mydata-finetuned
```
================================================
FILE: tutorials/finetune_lora.md
================================================
# Finetuning with LoRA / QLoRA
[Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.
We demonstrate this method by instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU** with CUDA 11.8.
## Preparation
The steps here only need to be done once:
1. Follow the instructions in the [README](../README.md) to install the dependencies.
2. Download and convert the weights and save them in the `./checkpoints` folder.
Weights can be downloaded following the instructions in the [download_model_weights](download_model_weights.md) documentation:
LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.
You can optionally [prepare your own dataset](#tune-on-your-dataset).
For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.
## Running the Finetuning
```bash
litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \
--data Alpaca
```
The finetuning requires at least one GPU with ~24 GB memory (RTX 3090).
This script will save checkpoints periodically to the folder `out/`.
> [!NOTE]
> LoRA can be applied to not only `query`, `key` or `value` matrices, but also to `projection`, `mlp` and classification `head`.
> According to [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): "LoRA on all linear transformer block layers are required to match full finetuning performance".
> By default LoRA is applied only to the `query` and `value` matrices. In order to apply LoRA to other weight matrices - change the arguments to `litgpt/finetune/lora.py` accordingly.
Optionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:
```bash
litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \
--quantize "bnb.nf4"
```
and optionally with double-quantization:
```bash
litgpt finetune_lora stabilityai/stablelm-base-alpha-3b \
--quantize "bnb.nf4-dq"
```
The table below lists a comparison with different settings on a StableLM 3B model finetuned with LoRA on Alpaca for 1,000 iterations using a microbatch size of 1:
| Settings | Training Memory | Training Time | Inference Memory |
|---------------------------------------------|-----------------|---------------|-------------------|
| Default (bf16-mixed) | 26.92 GB | 1.34 min | 21.43 GB |
| --precision bf16-true | 9.69 GB | 1.24 min | 7.30 GB |
| --precision bf16-true --quantize bnb.nf4 | 6.35 GB | 1.82 min | 3.20 GB |
| --precision bf16-true --quantize bnb.nf4-dq | 6.19 GB | 1.87 min | 3.04 GB |
The advantages of QLoRA-style quantization are more pronounced in larger models, such as Llama 2 7B. The table below summarizes the results for Llama 2 7B on Alpaca for 1,000 iterations using a microbatch size of 1:
| Settings | Training Memory | Training Time | Inference Memory |
|---------------------------------------------|------------------|---------------|------------------|
| Default (bf16-mixed) | OutOfMemoryError | N/A | 40.21 GB |
| --precision bf16-true | 21.30 GB | 2.36 min | 13.52 GB |
| --precision bf16-true --quantize bnb.nf4 | 14.14 GB | 3.68 min | 4.57 GB |
| --precision bf16-true --quantize bnb.nf4-dq | 13.84 GB | 3.83 min | 4.26 GB |
For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).
## Test the Model
You can test the finetuned model with your own instructions by running:
```bash
litgpt generate "out/lora/final" \
--prompt "Recommend a movie to watch on the weekend."
```
Output:
```text
I would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...
```
If your GPU supports `bfloat16`, you can additionally pass `--precision "bf16-true"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision "32-full"`). In addition, you may use quantization methods, for example `--precision "bf16-true" --quantize "bnb.nf4"` brings the memory consumption further down to ~4.4 GB for StableLM-3B.
## Tune on Your Dataset
You can easily train on your own instruction dataset saved in JSON format.
1. Create a JSON file in which each row holds one instruction-response pair.
A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.
Below is an example json file:
```text
[
{
"instruction": "Arrange the given numbers in ascending order.",
"input": "2, 4, 0, 8, 3", // Optional: only used in Alpaca chat template
"output": "0, 2, 3, 4, 8"
},
...
]
```
2. Run `litgpt finetune_lora` by passing in the location of your data (and optionally other parameters):
```bash
litgpt finetune_lora checkpoints/stabilityai/stablelm-base-alpha-3b \
--data JSON \
--data.json_path data/mydata.json \
--out_dir out_dir/mydata-finetuned
```
3. Test and use the finetuned model:
```bash
litgpt chat out_dir/mydata-finetuned/final
```
or
```bash
litgpt serve out_dir/mydata-finetuned/final
```
## Merging LoRA Weights (Optional)
Finetuning a model with LoRA generates a `lit_model.pth.lora` file.
This file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space.
> [!NOTE]
> LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`.
> Manual merging is only necessary if you want to use the checkpoint outside LitGPT.
If desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint.
The advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.
For example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows:
```bash
litgpt merge_lora "out/lora/step-002000"
```
The command above creates a full `lit_model.pth` checkpoint file.
================================================
FILE: tutorials/full_finetune_example.py
================================================
"""
This script is meant to be the simplest possible starting point for full finetuning a GPT model using lightning fabric with code (not CLI).
- no checkpoints
- no out dir
- no precision
- no resume
- no train/eval args (or any args in general)
- no logger (only to terminal)
- no grad accumulation
and no other fancy stuff.
To add all the above stuff, you can slowly add them in yourself by looking at the code in litgpt/finetune/full.py or the docs for litgpt/fabric.
"""
import os
import lightning as L
import torch
import torch.nn as nn
from litgpt.data import Alpaca
from litgpt.model import GPT, Config
from litgpt.tokenizer import Tokenizer
from litgpt.utils import num_parameters
# training params/args
SEED = 1337
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # try also "stabilityai/stablelm-base-alpha-3b"!
BATCH_SIZE = 4
LR_WARMUP_STEPS = 100
MAX_STEPS = 601
def validate(model, val_dataloader):
model.eval()
loss = 0
with torch.no_grad():
for batch in val_dataloader:
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
loss += nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:])
fabric.print(f"Validation loss: {loss / len(val_dataloader)}")
def train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader):
for iter_num, batch in enumerate(train_dataloader):
input_ids, targets = batch["input_ids"], batch["labels"]
# get model preds (logits)
logits = model(input_ids)
logits = logits.reshape(-1, logits.size(-1))
# get loss
targets = targets.reshape(-1)
loss = nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:])
# update weights
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
# print train loss every 100 steps
if iter_num % 100 == 0 or iter_num == 0:
fabric.print(f"Train iter {iter_num} - loss {loss}")
# validate every 300 steps
if iter_num % 300 == 0 or iter_num == 0:
validate(model, val_dataloader)
model.train()
iter_num += 1
if iter_num >= MAX_STEPS:
break
def main(fabric):
fabric.seed_everything(SEED)
# setup data, make tokenizer and make dataloaders
data = Alpaca()
tokenizer = Tokenizer(checkpoint_dir=f"checkpoints/{MODEL_NAME}")
data.connect(tokenizer=tokenizer, batch_size=BATCH_SIZE, max_seq_length=1024)
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
# print how many steps in an epoch
fabric.print(f"Steps in an epoch: {len(train_dataloader)}")
# setup model
config = Config.from_file(f"checkpoints/{MODEL_NAME}/model_config.yaml")
model = GPT(config)
fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
model = fabric.setup(model)
# setup optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02, betas=(0.9, 0.95))
optimizer = fabric.setup_optimizers(optimizer)
# setup lr scheduler
scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / LR_WARMUP_STEPS)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(MAX_STEPS - LR_WARMUP_STEPS))
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[LR_WARMUP_STEPS])
# Start training!!!
train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader)
if __name__ == "__main__":
# check that the model exists (downloaded to ./checkpoints/)
if not os.path.exists(f"checkpoints/{MODEL_NAME}"):
print(f"Model {MODEL_NAME} not found. Please download it using `litgpt download --repo {MODEL_NAME}`")
exit()
### Setup and launch
fabric = L.Fabric(devices="auto", strategy="auto")
fabric.launch(main)
================================================
FILE: tutorials/inference.md
================================================
# Inference
We demonstrate how to run inference (next token prediction) with the GPT base model in the [`litgpt generate`](../litgpt/generate/base.py) command:
```bash
litgpt generate stabilityai/stablelm-base-alpha-3b \
--prompt "Hello, my name is"
```
Output:
```text
Hello, my name is Levi Durrer, I'm an Austrian journalist - Chairman of the Press Blair Party, with 37 years in the Press Blair International, and two years in the Spectre of Austerity for the other. I'm crossing my fingers that you will feel
```
The script assumes you have downloaded and converted the weights as described [here](download_model_weights.md).
This will run the 3B pre-trained model and require ~7 GB of GPU memory using the `bfloat16` datatype.
## Run interactively
You can also chat with the model interactively:
```bash
litgpt chat stabilityai/stablelm-tuned-alpha-3b
```
This script can work with any checkpoint. For the best chat-like experience, we recommend using it with a checkpoints
fine-tuned for chatting such as `stabilityai/stablelm-tuned-alpha-3b` or `togethercomputer/RedPajama-INCITE-Chat-3B-v1`.
> [!TIP]
> Use `--multiline true` to work with inputs that span multiple lines.
## Run a large model on one smaller device
Check out our [quantization tutorial](quantize.md).
## Run a large model on multiple smaller devices
We offer two scripts to leverage multiple devices for inference.
### [`litgpt generate_sequentially`](../litgpt/generate/sequentially.py)
Allows you to run models that wouldn't fit in a single card by partitioning the transformer blocks across all your devices and running them sequentially.
For instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations.
With 80 transformer layers, we could partition them across 8, 5, 4, or 2 devices.
```shell
litgpt generate_sequentially meta-llama/Llama-2-70b-chat-hf \
--max_new_tokens 256 \
--num_samples 2
```
Using A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable.
| Devices | Max GPU RAM | Token/sec |
|---------|-------------|-----------|
| 2 | OOM | - |
| 4 | 35.64 GB | 7.55 |
| 5 | 28.72 GB | 7.49 |
| 8 | 18.35 GB | 7.47 |
Note that the memory usage will also depend on the `max_new_tokens` value used.
The script also supports quantization, using 4-bit precision, we can now use 2 GPUs
```shell
litgpt generate_sequentially meta-llama/Llama-2-70b-chat-hf \
--max_new_tokens 256 \
--num_samples 2 \
--quantize bnb.nf4-dq
```
| Devices | Max GPU RAM | Token/sec |
|---------|-------------|-----------|
| 2 | 20.00 GB | 8.63 |
| 4 | 10.80 GB | 8.23 |
| 5 | 8.96 GB | 8.10 |
| 8 | 6.23 GB | 8.18 |
Smaller devices can also be used to run inference with this technique.
### [`litgpt generate_tp`](../litgpt/generate/tp.py)
Uses tensor parallelism (TP) to run models that wouldn't fit in a single card by sharding the MLP and Attention QKV linear layers across all your devices.
For instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations.
The requirement is that the intermediate size (for the MLP) and the QKV size (for attention) is divisible by the number of devices.
With an intermediate size of 28672, we can use 2, 4, 7, or 8 devices. With a QKV size of 10240 we can use 2, 4, 5, or 8 devices.
Since the script is configured to shard both, the intersection is used: we can only use 2, 4, or 8 devices.
```shell
litgpt generate_tp meta-llama/Llama-2-70b-chat-hf \
--max_new_tokens 256 \
--num_samples 2
```
Using A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable.
| Devices | Max GPU RAM | Token/sec |
|---------|-------------|-----------|
| 2 | OOM | - |
| 4 | 35.46 GB | 9.33 |
| 8 | 18.19 GB | 8.61 |
Note that the memory usage will also depend on the `max_new_tokens` value used.
The script also supports quantization, using 4-bit precision, we can now use 2 GPUs
```shell
litgpt generate_tp meta-llama/Llama-2-70b-chat-hf \
--max_new_tokens 256 \
--num_samples 2 \
--quantize bnb.nf4-dq
```
| Devices | Max GPU RAM | Token/sec |
|---------|-------------|-----------|
| 2 | 19.79 GB | 6.72 |
| 4 | 10.73 GB | 6.48 |
| 8 | 6.15 GB | 6.20 |
Smaller devices can also be used to run inference with this technique.
================================================
FILE: tutorials/mkdocs.yml
================================================
site_name: LitGPT Tutorials
plugins:
- pagetree
theme:
name: material
================================================
FILE: tutorials/oom.md
================================================
## Dealing with out-of-memory (OOM) errors
If you got this error while running a script
```bash
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacity of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process
has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory
is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
it means that your GPU memory size wasn't big enough for the model and script configuration.
Here's a few things you can try:
### Reduce the micro batch size
Adjust the `--train.micro_batch_size` argument in the fine-tuning and pretraining scripts. This variable determines the number of samples loaded per iteration.
A smaller value will simply load fewer samples simultaneously. The minimum value is 1.
Experiment with different micro batch sizes to find a balance between memory consumption and computational efficiency. Smaller micro batch sizes consume less memory but may result in slower training convergence. Conversely, larger micro batch sizes require more memory but can accelerate training speed.
### Reduce the model's context length
The context length (`block_size` in the code) plays a significant role in running models with attention.
* The pretraining scripts are configured to use the full context length of the model to train.
* The finetuning scripts are configured to use the longest sample length of the training data to avoid allocating unnecessary memory (`--train.max_seq_length` argument).
If that's longer than the model's context length, an error is raised. If you try to run a batch that is longer than this, an error is raised.
However, your hardware may not support such large context lengths. Here's what you can do:
* For the pretraining scripts, you can simply reduce the `Config(block_size=...)` value.
* For the finetuning scripts, you can trim the length of the samples in your dataset.
All the finetuning scripts expose a `--data.max_seq_length=...` argument. This might also be useful in cases where
sample lengths are highly unbalanced, as the presence of a single very long sample would incur a larger memory usage for all other
shorter samples. For example, the median length of the samples in Alpaca is 110 tokens. Truncating the Alpaca dataset to 256 max tokens reduces the memory requirements of a Falcon 7B model from 23.52 GB to 15.73 GB. For more information about the dataset truncation, please see the *Truncating datasets* section in the [prepare_dataset.md](prepare_dataset.md) tutorial.
Keep in mind that reducing the context length will affect the modelling performance on text sequences longer than the limit.
### Use lower precision
Our scripts expose the `--precision` argument, this directly impacts the memory usage.
Using true lower precision (`16-true`, `bf16-true`) reduces the memory usage by half compared to `32-true`, however,
the model might start producing NaNs due to the limited range of representable values.
Mixed precision training (`16-mixed`, `bf16-mixed`) provides better stability but offers limited memory reduction.
### Do sharding across multiple GPUs
For exceptionally large models, the aforementioned techniques might still not suffice. If you have multiple GPUs available,
you can trade off memory for speed by changing the `--devices 1` argument in the scripts. Enabling this option enables a parallelism technique (FSDP), sharding the memory across different GPUs.
The default configuration already uses activation checkpointing, but you can enable CPU offloading by changing the `cpu_offload=False` argument in the scripts.
### Try a different optimizer
Our scripts use the [`AdamW` optimizer](https://pytorch.org/docs/main/generated/torch.optim.AdamW.html).
It maintains 2 states for each trainable parameter of the model, meaning that the optimizer memory is double compared to
an optimizer like [`SGD`](https://pytorch.org/docs/main/generated/torch.optim.SGD.html).
You can try replacing it with your optimizer of choice that is lighter in memory requirements. Keep in mind that different optimizers have distinct optimization behaviors, so it's essential to assess their impact on the training process and model performance.
An example would be the recently published [Sophia](https://arxiv.org/abs/2305.14342) or [Lion](https://arxiv.org/abs/2302.06675) optimizers.
This suggestion is particularly relevant for pretraining, as the trainable parameters in the model represent a small
subset of the total in the fine-tuning scripts.
================================================
FILE: tutorials/prepare_dataset.md
================================================
# Preparing Datasets
Below is a table of all datasets that are currently supported in LitGPT:
| Name | Task | Size | Reference Repo | Paper / Blog | Data License |
|--------------|-------------|---------------------|--------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Alpaca | Finetuning | 51,759 samples | [URL](https://github.com/tatsu-lab/stanford_alpaca) | [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html) | Attribution-NonCommercial 4.0 International, [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html) |
| Alpaca-2k | Finetuning | 2000 samples | [URL](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) | See Alpaca above | See Alpaca Above |
| Alpaca-GPT4 | Finetuning | 52,002 samples | [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | [URL](https://arxiv.org/abs/2304.03277) | Attribution-NonCommercial 4.0 International, [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/DATA_LICENSE) |
| Alpaca Libre | Finetuning | 55,370 samples | [URL](https://github.com/mobarski/alpaca-libre) | - | CC0/MIT, [URL](https://github.com/mobarski/alpaca-libre) |
| Deita | Finetuning | 9,500 samples | [URL](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft/tree/main/data) | [URL](https://arxiv.org/abs/2312.15685) | MIT [URL](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0/blob/main/README.md) |
| Dolly | Finetuning | 15,011 samples | [URL](https://github.com/databrickslabs/dolly/tree/master/data) | [URL](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | CC-BY-SA, [URL](https://github.com/databrickslabs/dolly#model-overview) |
| FLAN | Finetuning | 1,753,240 samples | [UR](https://huggingface.co/datasets/Muennighoff/flan) | [URL](https://blog.research.google/2023/02/the-flan-collection-advancing-open.html) | Subset dependent |
| LongForm | Finetuning | 23,652 samples | [URL](https://github.com/akoksal/LongForm) | [URL](https://arxiv.org/abs/2304.08460) | No information provided and subset-dependent, [URL](https://github.com/akoksal/LongForm) |
| LIMA | Finetuning | 1,084 samples | [URL](https://huggingface.co/datasets/GAIR/lima) | [URL](https://arxiv.org/abs/2305.11206) | "If the source data of LIMA has a stricter license than CC BY-NC-SA, the LIMA dataset follows the same. Otherwise, it follows the CC BY-NC-SA license", [URL](https://huggingface.co/datasets/GAIR/lima#license) |
| OpenWeb Text | Pretraining | 8,013,769 documents | [URL](https://github.com/jcpeterson/openwebtext) | [URL](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) | Unspecified |
| TinyLlama | Pretraining | 1 T tokens | [URL](https://github.com/jzhang38/TinyLlama) | [URL](https://arxiv.org/abs/2401.02385) | |
| TinyStories | Pretraining | 4,967,871 stories | [URL](https://huggingface.co/datasets/roneneldan/TinyStories) | [URL](https://arxiv.org/abs/2305.07759) | CDLA-Sharing-1.0 |
## Preparation
The steps here only need to be done once before preparing the finetuning datasets in the following subsections:
1. Follow the instructions in the [README](../README.md) to install the dependencies.
2. Download and convert the weights following our [guide](download_model_weights.md).
For the following examples, we will focus on finetuning with the `litgpt finetune_lora` command and use a Falcon 7B model.
However, the same steps apply to all other models and finetuning scripts.
Please read the [tutorials/finetune_*.md](.) documents for more information about finetuning models.
> [!IMPORTANT]
> By default, the maximum sequence length is obtained from the model configuration file. In case you run into out-of-memory errors, especially in the cases of LIMA and Dolly,
> you can try to lower the context length by setting the `--train.max_seq_length` parameter, for example, `litgpt finetune lora --train.max_seq_length 256`. For more information on truncating datasets, see the *Truncating datasets* section in the Alpaca section near the top of this article.
### Alpaca
The Alpaca dataset consists of 52,000 instructions and demonstrations produced by OpenAI's text-davinci-003 engine. This data is used in instruction-tuning, helping improve the performance of language models to follow instructions.
In its development, the creators leveraged the data generation methodology from the [Self-Instruct framework](https://github.com/yizhongw/self-instruct).
The original [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) dataset can be used as follows:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Alpaca
```
> [!TIP]
> Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.
#### Truncating datasets
By default, the finetuning scripts will determine the size of the longest tokenized sample in the dataset to determine the block size. However, if you are willing to truncate a few examples in the training set, you can reduce the computational resource requirements significantly. For instance you can set a sequence length threshold via `--train.max_seq_length`. We can determine an appropriate maximum sequence length by considering the distribution of the data sample lengths shown in the histogram below.
In this case, a cut-off of 256 may be a reasonable choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Alpaca \
--train.max_seq_length 256
```
For comparison, the Falcon 7B model requires 23.52 GB of memory for the original Alpaca dataset and 15.73 GB of memory for the truncated Alpaca dataset when finetuning with LoRA using a micro batchsize of 1 and bfloat-16 precision.
### Alpaca-2k
[Alpaca-2k](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) is a smaller, 2000-sample subset of Alpaca described above.
```bash
litgpt finetune_lora "tiiuae/falcon-7b" \
--data Alpaca2k
```
> [!TIP]
> Use `litgpt_finetune --data.help Alpaca2k` to list additional dataset-specific command line options.
The Alpaca-2k dataset distribution is shown below.
### Alpaca-GPT4
The Alpaca-GPT4 was built by using the prompts of the original Alpaca dataset and generate the responses via GPT 4. The
dataset consists of 52,000 instructions and responses.
The original [Alpaca-GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset can be used as follows:
```bash
litgpt finetune lora "tiiuae/falcon-7b" \
--data AlpacaGPT4
```
> [!TIP]
> Use `litgpt_finetune --data.help AlpacaGPT4` to list additional dataset-specific command line options.
The Alpaca-GPT4 dataset distribution is shown below.
### Alpaca Libre
[Alpaca Libre](https://github.com/mobarski/alpaca-libre) is a reimplementation or alternative to Alpaca using the same formatting.
To use Alpaca Libre instead of the original Alpaca dataset, use the following command:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Alpaca \
--data.file_url "https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json" \
--data.file_name "alpaca_libre_data_cleaned_archive.json"
```
> [!TIP]
> Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.
The Alpaca Libre dataset distribution is shown below.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 256 may be a good choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Alpaca \
--data.file_url "https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json" \
--data.file_name "alpaca_libre_data_cleaned_archive.json" \
--train.max_seq_length 256
```
### Deita
The Deita dataset (short for Data-Efficient Instruction Tuning for Alignment) is a collection of 9500 prompts and responses, as described in the [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) paper.
Using Falcon 7b as an example, we can use the dataset as follows:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Deita
```
> [!TIP]
> Use `litgpt finetune --data.help Deita` to list additional dataset-specific command line options.
Deita contains multiturn conversations. By default, only the first instruction-response pairs from
each of these multiturn conversations are included. If you want to override this behavior and include the follow-up instructions
and responses, set `--data.include_multiturn_conversations True`, which will include all multiturn conversations as regular
prompt-response pairs. Considering the multiturn-answers, the dataset consists of 209,272 prompt-response pairs.
The Deita dataset distribution without including multit-turn conversations is shown below.
The Deita dataset distribution including multit-turn conversations is depicted in the following histogram.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 512 may be a good choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Deita \
--train.max_seq_length 512
```
### Dolly
The Dolly dataset is a publicly available collection of 15k instruction-following entries created by Databricks. It spans multiple behavioral domains, as described in the [InstructGPT paper](https://arxiv.org/abs/2203.02155) paper. These include areas like brainstorming, classification, closed QA, content creation, information retrieval, open QA, and summary generation.
The usage is similar to the Alpaca dataset described above. Using Falcon 7b as an example, we can use the dataset as follows:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Dolly
```
> [!TIP]
> Use `litgpt finetune --data.help Dolly` to list additional dataset-specific command line options.
The Dolly dataset distribution is shown below.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 512 may be a good choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data Dolly \
--train.max_seq_length 256
```
### LongForm
LongForm is a semi-synthetic dataset based on raw text corpora for which the instructions were generated via an LLM. For more details about the instruction-generation process, please refer to the [LongForm research paper](https://arxiv.org/abs/2304.08460) by Köksal et al. According to the research paper, a Llama 7B model trained on LongForm achieves substantially better performance than the same Llama model trained on the 2x larger Alpaca dataset.
LongForm consists of 23,652 training samples, 2,042 validation samples, and 2,045 test samples. (In LitGPT, the validation samples are currently not used.)
The more detailed dataset composition is as follows based on a table taken from the [dataset repository](https://github.com/akoksal/LongForm):
| **Type** | **Source** | **Number of Examples** |
|------------------------|----------------|------------------------|
| **Corpora** | C4 | 10,000 |
| | Wikipedia | 5,000 |
| **Structured Corpora** | Stack Exchange | 4,380 |
| | WikiHow | 2,500 |
| **Tasks** | NIv2 | 3,684 |
| | Big Bench | 600 |
| | BEA-GEC | 1,203 |
| | Enron | 372 |
| **Total** | | 27,739 |
| | | |
| **Train** | | 23,652 |
| **Validation** | | 2,042 |
| **Test** | | 2,045 |
License information is not provided but would depend on the individual subsets listed above.
The LongForm dataset distribution is shown below.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 1500 may be a good choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data LongForm \
--train.max_seq_length 1500
```
> [!TIP]
> Use `litgpt finetune --data.help LongForm` to list additional dataset-specific command line options.
### LIMA
The LIMA dataset is a collection of 1,000 carefully curated prompts and responses, as described in the [LIMA: Less Is More for Alignment](https://arxiv.org/abs/2305.11206) paper. The dataset is sourced from three community Q&A websites: Stack Exchange, wikiHow, and the Pushshift Reddit Dataset. In addition, it also contains prompts and answers written and collected by the authors of the LIMA paper.
The usage is similar to the Dolly dataset described above except that it requires an Hugging Face access token that you need to copy & paste from your Hugging Face account. Using Falcon 7b as an example, we can use the dataset as follows:
```bash
export HF_TOKEN="insert_your_huggingface_token_here"
litgpt finetune lora \
--data LIMA \
--checkpoint_dir "tiiuae/falcon-7b"
```
> [!TIP]
> Use `litgpt finetune --data.help LIMA` to list additional dataset-specific command line options.
LIMA contains a handful of multiturn conversations. By default, only the first instruction-response pairs from
each of these multiturn conversations are included. If you want to override this behavior and include the follow-up instructions
and responses, set `--data.include_multiturn_conversations True`.
The LIMA dataset distribution is shown below.
You may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 512 may be a good choice:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data LIMA \
--train.max_seq_length 512
```
### FLAN
FLAN is a collection of several dataset subsets by Google. In particular, the provided script in LitGPT loads the subsets from
[here](https://huggingface.co/datasets/Muennighoff/flan).
By default, all subsets (1,386,050 samples) and validations sets (367,190 subsets) are combined into a single dataset:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data FLAN
```
However, you can also select individual subsets via comma-separated strings as follows:
```bash
litgpt finetune lora tiiuae/falcon-7b \
--data FLAN \
--data.subsets "aeslc_10templates,ag_news_subset_10templates,anli_r1_10templates"
```
> [!TIP]
> Use `litgpt finetune --data.help FLAN` to list additional dataset-specific command line options.
You can find a list of all 66 supported subsets [here](https://huggingface.co/datasets/Muennighoff/flan).
## Preparing Custom Datasets for Instruction Finetuning
The models in LitGPT expect datasets for instruction finetuning in the following format:
```text
[
{
"instruction": "Write a limerick about a
pelican.”,
"input": "",
"output": "There once was a pelican so fine,
\nHis beak was as colorful as
sunshine,\nHe would fish all day,\nIn
a very unique way,\nThis pelican was
truly divine!\n\n\n"
},
{
"instruction": "Identify the odd one out from
the group.",
"input": "Carrot, Apple, Banana, Grape",
"output": "Carrot\n\n"
},
]
```
(Note that depending on the task, the `"input"` text can be an empty string, as shown above.)
You can use your own data in LitGPT by either reading in a JSON file in the format shown above or by implementing a custom `DataModule`.
### Preparing Custom Datasets From a JSON File
You can prepare custom dataset using a JSON file where each row is a dictionary with these keys:
- `instruction`: Column which will describe the task.
- `input`: A string holding a special input value for the instruction. This applies to some samples, and in others, this is empty (empty string).
- `output`: The expected response
> If any of the fields are missing, then the script will fail to read the dataset.
Then simply run any of the finetuning scripts with this input:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data JSON \
--data.json_path path/to/your/data.json \
--data.val_split_fraction 0.1
```
You can also customize how the dataset is read by using these additional parameters
- `val_split_fraction`: The fraction of the data to split. Defaults to `0.1`
- `seed`: The seed value to reproduce the same random splits for train and test data.
- `mask_inputs`: Whether to mask the prompt section from the label (with `ignore_index`).
- `ignore_index`: The index to use for labels that should be ignored. Defaults to `-100` (used when `mask_inputs` is `True`).
To use the settings described above, you can add the respective command line arguments when calling the finetuning scripts as shown in the example below:
```bash
litgpt finetune_lora tiiuae/falcon-7b \
--data JSON \
--data.json_path path/to/your/data.json \
--data.val_split_fraction 0.1 \
--data.seed 42 \
--data.mask_inputs False \
--data.ignore_index -100
```
You can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split.
> [!TIP]
> Use `litgpt finetune --data.help JSON` to list additional dataset-specific command line options.
### Preparing Custom Datasets Using DataModule
If you don't have a JSON file following the format described in the previous section, the easiest way to prepare a new dataset is to copy and modify one of the existing data modules in LitGPT:
- [`litgpt/data/alpaca.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/alpaca.py) (if you plan to load a dataset from a JSON file);
- [`litgpt/data/lima.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/lima.py) (if you plan to load a dataset using the `datasets` Python library).
Note that you only need to modify a small fraction of the code file, namely the portion that downloads and formats the training data (see the `prepare_data` and `setup()` methods).
## Preparing Pretraining Datasets
In addition to the finetuning dataset described above, LitGPT also supports several datasets for pretraining. The pretraining datasets are described in more detail in the following separate tutorial documents:
- [Pretrain TinyLlama on Slimpajama and Starcoder](./pretrain_tinyllama.md)
================================================
FILE: tutorials/pretrain.md
================================================
# Pretrain LLMs with LitGPT
This document explains how to pretrain LLMs using LitGPT.
## Using the `litgpt pretrain` command
You can pretrain models in LitGPT using the `litgpt pretrain` API starting with any of the available architectures listed by calling `litgpt pretrain list` without any additional arguments:
> [!TIP]
> To install all required dependencies before pretraining, first run `pip install "litgpt[all]"`.
```bash
litgpt pretrain list
```
Shown below is an abbreviated list:
```
ValueError: Please specify --model_name . Available values:
Camel-Platypus2-13B
...
Gemma-2b
...
Llama-2-7b-hf
...
Mixtral-8x7B-v0.1
...
pythia-14m
```
For demonstration purposes, we can pretrain a small 14 million-parameter Pythia model on the small TinyStories dataset using the [debug.yaml config file](https://github.com/Lightning-AI/litgpt/blob/main/config_hub/pretrain/debug.yaml) as follows:
```bash
litgpt pretrain pythia-14m \
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/debug.yaml
```
## Pretrain on custom data
The simplest way to get started with pretraining on a small custom dataset is by using the `TextFiles` data module, which lets you pretrain a dataset from a folder containing plain text files.
> [!NOTE]
> This approach adds a beginning-of-sequence token at the beginning of each text file. However, it otherwise assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable in case a text file conists of multiple documents.
> [!WARNING]
> Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, this `TextFiles` approach is only feasible to store the data in plain text files due to the limited size.
> For datasets that take up multiple gigabytes, we recommend preprocessing it with [LitData](https://github.com/Lightning-AI/litdata) and then reading it from a local directory or S3 connection using `--data LitData`.
For instance, assume you stored a number of text files in a `custom_pretraining_dataset` folder (we recommend avoiding small files and concatenating them to files of at least 50 Mb for efficiency):
```bash
~ ls -lh custom_pretraining_data
total 3225M
-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_1.txt
-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_2.txt
-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_3.txt
-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_4.txt
-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_5.txt
...
```
You can then use the `TextFiles` API to pretrain a model (here a small `pythia-14m` model for illustration purposes) from scratch as follows:
```bash
litgpt download EleutherAI/pythia-14m \
--tokenizer_only true
litgpt pretrain pythia-14m \
--tokenizer_dir EleutherAI/pythia-14m \
--data TextFiles \
--data.train_data_path custom_pretraining_data \
--train.lr_warmup_steps=200 \
--optimizer AdamW \
--optimizer.lr 0.005
```
> [!TIP]
> Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.
## Continued pretraining on custom data
Often, it makes sense to adopt an existing pretrained model and further pretrain it on our own custom data. The existing pretrained model can be either our own pretrained model or a model downloaded from a model hub.
The following subsections illustrate three typical scenarioes:
1. Starting from a downloaded base model
2. Continuing the pretraining after interruption
3. Further pretraining on a different dataset
> [!NOTE]
> This approach assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable.
> [!WARNING]
> Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, this `TextFiles` approach is only feasible to store the data in plain text files due to the limited size.
> For datasets that take up multiple gigabytes, we recommend preprocessing it with [LitData](https://github.com/Lightning-AI/litdata) and then reading it from a local directory or S3 connection using `--data LitData --data.path path/to/your/data`.
### 1) Continued pretraining when starting from a downloaded base model
For instance, let's assume we download a Pythia model:
```bash
litgpt download EleutherAI/pythia-160m
```
Next, assume we have a custom dataset stored in text files similar to the *Pretrain on custom data* above. We can further pretrain the Pythia model via the `--initial_checkpoint_dir` setting as follows:
```bash
litgpt pretrain pythia-160m \
--initial_checkpoint_dir EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--out_dir ./new_pretrained_checkpoint \
--data TextFiles \
--data.train_data_path custom_pretraining_data \
--train.max_tokens 1_000_000
```
> [!TIP]
> Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.
### 2) Continued pretraining after interruption
In case a you interrupted a training run, you can continue it with the `--resume` option, for example:
```bash
litgpt pretrain pythia-160m \
--resume "auto" \
--tokenizer_dir EleutherAI/pythia-160m \
--out_dir ./new_pretrained_checkpoint \
--data TextFiles \
--data.train_data_path custom_pretraining_data \
--train.max_tokens 1_000_000
```
### 3) Continued pretraining on a new dataset
Suppose you pretrained a model using the examples above. To further pretrain the model on a new dataset, you first need to convert the pretrained checkpoint via the following command:
```bash
litgpt convert_pretrained_checkpoint ./new_pretrained_checkpoint/final ./new_pretrained_checkpoint_converted
```
Then, you can pretrain the converted model on the new dataset as follows:
```bash
litgpt pretrain pythia-160m \
--initial_checkpoint_dir ./new_pretrained_checkpoint_converted \
--tokenizer_dir EleutherAI/pythia-160m \
--out_dir ./new_pretrained_checkpoint_2 \
--data TextFiles \
--data.train_data_path custom_pretraining_data_2 \
--train.max_tokens 1_000_000
```
## Pretrain a 1.1B TinyLlama model
You can find an end-to-end LitGPT tutorial for pretraining a TinyLlama model using LitGPT [here](pretrain_tinyllama.md).
## Optimize LitGPT pretraining with Lightning Thunder
[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) is a source-to-source compiler for PyTorch, which is fully compatible with LitGPT. In experiments, Thunder resulted in a 40% speed-up compared to using regular PyTorch when finetuning a 7B Llama 2 model.
For more information, see the [Lightning Thunder extension README](https://github.com/Lightning-AI/lightning-thunder).
## Project templates
The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:
| | |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[
](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)
[
](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) |
| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)
[
](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | |
| |
================================================
FILE: tutorials/pretrain_tinyllama.md
================================================
# Pretrain TinyLlama
This tutorial will walk you through pretraining [TinyLlama](https://github.com/jzhang38/TinyLlama/).
> [!TIP]
> To get started with zero setup, clone the [TinyLlama studio on Lightning AI](https://lightning.ai/lightning-ai/studios/llm-pretrain-tinyllama-1-1b).
## What's TinyLlama?
[TinyLlama](https://github.com/jzhang38/TinyLlama/) is architecturally the same as Meta AI's LLama 2, but only has 1.1B parameters and is instead trained on multiple epochs on a mix of [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) and [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) datasets.
Here is a quick fact sheet:
| Name | Description |
|-------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Parameters | 1.1B |
| Model Size | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size: 5632 |
| Sequence Length | 2048 |
| Learning Rate | 4e-4 |
| Learning Rate Schedule | Cosine with 2000 warmup steps |
| Training Data | [SlimPajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) (893 GB), [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) (290 GB) |
| Combined Dataset Size | Around 950B tokens |
| Total Tokens During Training | 3 trillion (3 epochs) |
| Time to complete training | ~ 4 weeks with 64 A100 GPUs |
| Model FLOPs Utilization (MFU) | 52% |
(this table was sourced from the author's [README](https://github.com/jzhang38/TinyLlama/))
## Download datasets
You can download the data using git lfs:
```bash
# Make sure you have git-lfs installed (https://git-lfs.com):
sudo apt install git-lfs
```
```bash
git clone https://huggingface.co/datasets/cerebras/slimpajama-627b data/slimpajama-raw
git clone https://huggingface.co/datasets/bigcode/starcoderdata data/starcoderdata-raw
```
Around 1.2 TB of disk space is required to store both datasets.
## Prepare the datasets for training
In order to start pretraining litgpt on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `litdata` optimization pipeline and streaming dataset.
First, install additional dependencies for preprocessing:
```bash
pip install '.[all]'
```
You will need to have the tokenizer config available:
```bash
litgpt download meta-llama/Llama-2-7b-hf \
--access_token your_hf_token \
--tokenizer_only true
```
Then, run the preprocessing script for each dataset and split.
You will require **1.1 TB** of disk space for Starcoder and **2.5** TB of space for the SlimPajama dataset.
**Starcoder:**
```bash
python litgpt/data/prepare_starcoder.py \
--input_dir data/starcoderdata-raw \
--output_dir data/starcoder \
--tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf
```
**SlimPajama:**
```bash
python litgpt/data/prepare_slimpajama.py \
--input_dir data/slimpajama-raw/validation \
--output_dir data/slimpajama/val \
--tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf
python litgpt/data/prepare_slimpajama.py \
--input_dir data/slimpajama-raw/test \
--output_dir data/slimpajama/test \
--tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf
python litgpt/data/prepare_slimpajama.py \
--input_dir data/slimpajama-raw/train \
--output_dir data/slimpajama/train \
--tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf
```
If you want to run on a small slice of the datasets first, pass the flag `--fast_dev_run=true` to the commands above.
In the above we are assuming that you will be using the same tokenizer as used in LlaMA/TinyLlama, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here.
## Pretraining
Running the pretraining script with its default settings requires at least 8 A100 GPUs.
```bash
litgpt pretrain --config config_hub/pretrain/tinyllama.yaml
```
> [!TIP]
> Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset options.
The script will save checkpoints periodically to the folder `out/`.
By default, the `pretrain` script will pretrain the model with FSDP in
`bfloat16` mixed precision and gradient accumulation.
Note that `pretrain` is not actually a model-specific training script, so feel free [try other configurations](../config_hub)
or change the model type and size by passing a different string to the model name argument, for example:
```shell
litgpt pretrain Gemma-2b
```
The currently supported model names can be listed by executing `litgpt pretrain` without any additional arguments.
Keep in mind that training with a single machine will take weeks. To speed up the process, you'll need access to a cluster.
Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/fundamentals/launch.html#launch-on-a-cluster)
to launch the script across machines:
- [Lightning AI](https://lightning.ai/docs/fabric/stable/guide/multi_node/cloud.html)
- [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
- [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
- [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)
The script exposes several hyperparameters you can tweak through the command line.
For instance, `--train.micro_batch_size` should be adjusted so the process will use the available
GPU memory. For more tips to avoid out-of-memory issues, please also see the more detailed
[Dealing with out-of-memory (OOM) errors](oom.md) guide.
Last, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker.
LitGPT supports multiple experiment trackers including:
- **TensorBoard** (default): Local visualization with TensorBoard
- **CSV Logger**: Simple local logging to CSV files
- **WandB**: Cloud-based experiment tracking with Weights & Biases
- **MLflow**: MLflow experiment tracking
- **[LitLogger](https://github.com/Lightning-AI/LitLogger)**: Lightning.ai's native experiment tracking (set `--logger_name=litlogger`)
As an example, we included WandB (set `--logger_name=wandb`) to show how you can integrate any experiment tracking framework.
For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).
## Resume training
The checkpoints saved during pretraining contain all the information to resume if needed.
Simply rerun the script with the `--resume` argument added:
```bash
litgpt pretrain tiny-llama\
--config config_hub/pretrain/tinyllama.yaml \
--resume out/pretrain/tiny-llama/step-00060500
```
**Important:** Each checkpoint is a directory. Point to the directory, not the 'lit_model.pth' file inside of it.
> [!TIP]
> Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset options.
## Export checkpoints
After training is completed, you can convert the checkpoint to a format that can be loaded for evaluation, inference, finetuning etc.
```bash
litgpt convert_pretrained_checkpoint out/pretrain/tiny-llama/step-00060500 \
--output_dir checkpoints/tiny-llama/final
```
After conversion, the output folder will contain these files:
```
checkpoints/tiny-llama/final
├── model_config.yaml
├── lit_model.pth
├── tokenizer_config.json
├── tokenizer.json
└── tokenizer.model
```
You can then use this checkpoint folder to run [evaluation](evaluation.md), [inference](inference.md), [finetuning](finetune_lora.md) or [process the checkpoint further](convert_lit_models.md).
## Project templates
The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:
| | |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[
](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)
[
](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) |
| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)
[
](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | |
| |
================================================
FILE: tutorials/python-api.md
================================================
# LitGPT Python API
This is a work-in-progress draft describing the current LitGPT Python API (experimental and subject to change).
## Model loading
Use the `LLM.load` method to load a model from a LitGPT model checkpoint folder. For example, consider loading a Phi-2 model. If a given checkpoint directory `"microsoft/phi-2"` does not exist as a local checkpoint directory, the model will be downloaded automatically from the HF Hub (assuming that `"microsoft/phi-2"` is a valid repository name):
```python
from litgpt import LLM
llm_1 = LLM.load("microsoft/phi-2")
```
```
config.json: 100%|████████████████████████████████████████████████| 735/735 [00:00<00:00, 7.75MB/s]
generation_config.json: 100%|█████████████████████████████████████| 124/124 [00:00<00:00, 2.06MB/s]
model-00001-of-00002.safetensors: 100%|███████████████████████████| 5.00G/5.00G [00:12<00:00, 397MB/s]
model-00002-of-00002.safetensors: 100%|███████████████████████████| 564M/564M [00:01<00:00, 421MB/s]
model.safetensors.index.json: 100%|███████████████████████████████| 35.7k/35.7k [00:00<00:00, 115MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 21.5MB/s]
tokenizer_config.json: 100%|██████████████████████████████████████| 7.34k/7.34k [00:00<00:00, 80.6MB/s]
```
> [!NOTE]
> To get a list of all supported models, execute `litgpt download list` in the command line terminal.
If you attempt to load the model again, LitGPT will load this model from a local directory since it's already been downloaded:
```python
llm_2 = LLM.load("microsoft/phi-2")
```
If you created a pretrained or finetuned model checkpoint via LitGPT, you can load it in a similar fashion:
```python
my_llm = LLM.load("path/to/my/local/checkpoint")
```
## Generate/Chat
Generate output using the `.generate` method:
```python
from litgpt import LLM
llm = LLM.load("microsoft/phi-2")
text = llm.generate("What do Llamas eat?", top_k=1, max_new_tokens=30)
print(text)
```
```
Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract
```
Alternatively, stream the response one token at a time:
```python
result = llm.generate("hi", stream=True)
for e in result:
print(e, end="", flush=True)
```
```
Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract
```
## Saving models
After finetuning or modifying a model, you can save it to disk using the `.save()` method:
```python
from litgpt import LLM
llm = LLM.load("microsoft/phi-2")
# ... perform finetuning or modifications ...
llm.save("path/to/save/directory")
```
The saved checkpoint can then be loaded later:
```python
llm = LLM.load("path/to/save/directory")
```
## Random weights
To start with random weights, for example, if you plan a pretraining script, initialize the model with `init="random"`. Note that this requires passing a `tokenizer_dir` that contains a valid tokenizer file.
```python
from litgpt.api import LLM
llm = LLM.load("pythia-160m", init="random", tokenizer_dir="EleutherAI/pythia-160m")
```
## Multi-GPU strategies
By default, the model is loaded onto a single GPU. Optionally, you can use the `.distribute()` method with the "sequential" or "tensor_parallel" `generate_strategy` settings.
### Sequential strategy
The `generate_strategy="sequential"` setting loads different parts of the models onto different GPUs. The goal behind this strategy is to support models that cannot fit into single-GPU memory. (Note that if you have a model that can fit onto a single GPU, this sequential strategy will be slower.)
```python
from litgpt.api import LLM
llm = LLM.load(
"microsoft/phi-2",
distribute=None
)
llm.distribute(
generate_strategy="sequential",
devices=4, # Optional setting, otherwise uses all available GPUs
fixed_kv_cache_size=256 # Optionally use a small kv-cache to further reduce memory usage
)
```
```
Using 4 devices
Moving '_forward_module.transformer.h.31' to cuda:3: 100%|██████████| 32/32 [00:00<00:00, 32.71it/s]
```
After initializing the model, the model can be used via the `generate` method similar to the default `generate_strategy` setting:
```python
text = llm.generate("What do llamas eat?", max_new_tokens=100)
print(text)
```
```
Llamas are herbivores and their diet consists mainly of grasses, plants, and leaves.
```
### Tensor parallel strategy
The sequential strategy explained in the previous subsection distributes the model sequentially across GPUs, which allows users to load models that would not fit onto a single GPU. However, due to this method's sequential nature, processing is naturally slower than parallel processing.
To take advantage of parallel processing via tensor parallelism, you can use the `generate_strategy="tensor_parallel" setting. However, this method has downsides: the initial setup may be slower for large models, and it cannot run in interactive processes such as Jupyter notebooks.
```python
from litgpt.api import LLM
if __name__ == "__main__":
llm = LLM.load(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
distribute=None
)
llm.distribute(generate_strategy="tensor_parallel", devices=4)
print(llm.generate(prompt="What do llamas eat?"))
print(llm.generate(prompt="What is 1+2?", top_k=1))
```
## Speed and resource estimates
Use the `.benchmark()` method to compare the computational performance of different settings. The `.benchmark()` method takes the same arguments as the `.generate()` method. For example, we can estimate the speed and GPU memory consumption as follows (the resulting numbers were obtained on an A10G GPU):
```python
from litgpt.api import LLM
from pprint import pprint
llm = LLM.load(
model="microsoft/phi-2",
distribute=None
)
llm.distribute(fixed_kv_cache_size=500)
text, bench_d = llm.benchmark(prompt="What do llamas eat?", top_k=1, stream=True)
print(text)
pprint(bench_d)
# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized
# digestive system that allows them to efficiently extract nutrients from plant material.
# Using 1 device(s)
# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material.
# {'Inference speed in tokens/sec': [17.617540650112936],
# 'Seconds to first token': [0.6533610639999097],
# 'Seconds total': [1.4758019020000575],
# 'Tokens generated': [26],
# 'Total GPU memory allocated in GB': [5.923729408]}
```
To get more reliably estimates, it's recommended to repeat the benchmark for multiple iterations via `num_iterations=10`:
```python
text, bench_d = llm.benchmark(num_iterations=10, prompt="What do llamas eat?", top_k=1, stream=True)
print(text)
pprint(bench_d)
# Using 1 device(s)
# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material.
# {'Inference speed in tokens/sec': [17.08638672485105,
# 31.79908547222976,
# 32.83646959864293,
# 32.95994240022436,
# 33.01563039816964,
# 32.85263413816648,
# 32.82712094713627,
# 32.69216141907453,
# 31.52431714347663,
# 32.56752130561681],
# 'Seconds to first token': [0.7278506560005553,
# 0.022963577999689733,
# 0.02399449199947412,
# 0.022921959999621322,
# ...
```
As one can see, the first iteration may take longer due to warmup times. So, it's recommended to discard the first iteration:
```python
for key in bench_d:
bench_d[key] = bench_d[key][1:]
```
For better visualization, you can use the `benchmark_dict_to_markdown_table` function
```python
from litgpt.api import benchmark_dict_to_markdown_table
print(benchmark_dict_to_markdown_table(bench_d_list))
```
| Metric | Mean | Std Dev |
|-------------------------------------|-----------------------------|-----------------------------|
| Seconds total | 0.80 | 0.01 |
| Seconds to first token | 0.02 | 0.00 |
| Tokens generated | 26.00 | 0.00 |
| Inference speed in tokens/sec | 32.56 | 0.50 |
| Total GPU memory allocated in GB | 5.92 | 0.00 |
# PyTorch Lightning Trainer support
You can use the LitGPT `LLM` class with the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to pretrain and finetune models.
The examples below show the usage via a simple 160 million parameter model for demonstration purposes to be able to quickly try it out. However, you can replace the `EleutherAI/pythia-160m` model with any model supported by LitGPT (you can find a list of supported models by executing `litgpt download list` or visiting the [model weight docs](download_model_weights.md)).
## Step 1: Define a `LightningModule`
First, we define a `LightningModule` similar to what we would do when working with other types of neural networks in PyTorch Lightning:
```python
import torch
import litgpt
from litgpt import LLM
from litgpt.data import Alpaca2k
import lightning as L
class LitLLM(L.LightningModule):
def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):
super().__init__()
self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)
self.trainer_ckpt_path = trainer_ckpt_path
def setup(self, stage):
self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)
def training_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("train_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch):
logits, loss = self.llm(input_ids=batch["input_ids"], target_ids=batch["labels"])
self.log("validation_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
warmup_steps = 10
optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
return [optimizer], [scheduler]
```
In the code example above, note how we set `distribute=None` in `llm.load()` in the `__init__` method. This step is necessary because we want to let the PyTorch Lightning Trainer handle the GPU devices. We then call `self.llm.trainer_setup` in the `setup()` method, which adjusts the LitGPT settings to be compatible with the Trainer. Other than that, everything else looks like a standard `LightningModule`.
Next, we have a selection of different use cases, but first, let's set some general settings to specify the batch size and gradient accumulation steps:
```python
batch_size = 8
accumulate_grad_batches = 1
```
For larger models, you may want to decrease the batch size and increase the number of accumulation steps. (Setting `accumulate_grad_batches = 1` effectively disables gradient accumulation, and it is only shown here for reference in case you wish to change this setting.)
## Step 2: Using the Trainer
### Use case 1: Pretraining from random weights
In case you plan to train a model from scratch (not recommended over finetuning because training a model from scratch in general requires substantial time and resources), you can do it as follows:
```python
# Create model with random as opposed to pretrained weights
llm = LLM.load("EleutherAI/pythia-160m", tokenizer_dir="EleutherAI/pythia-160m", init="random")
llm.save("pythia-160m-random-weights")
del llm
lit_model = LitLLM(checkpoint_dir="pythia-160m-random-weights", tokenizer_dir="EleutherAI/pythia-160m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
```
### Use case 2: Continued pretraining or finetuning a downloaded model
The continued pretraining or finetuning from a downloaded model checkpoint is similar to the example above, except that we can skip the initial steps of instantiating a model with random weights.
```python
lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m")
data = Alpaca2k()
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
```
### Use case 3: Resume training from Trainer checkpoint
Suppose you trained a model and decide to follow up with a few additional training rounds. This can be achieved as follows by loading an existing Trainer checkpoint:
```python
import os
def find_latest_checkpoint(directory):
latest_checkpoint = None
latest_time = 0
for root, _, files in os.walk(directory):
for file in files:
if file.endswith('.ckpt'):
file_path = os.path.join(root, file)
file_time = os.path.getmtime(file_path)
if file_time > latest_time:
latest_time = file_time
latest_checkpoint = file_path
return latest_checkpoint
lit_model = LitLLM(checkpoint_dir="EleutherAI/pythia-160m", trainer_ckpt_path=find_latest_checkpoint("lightning_logs"))
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
```
### Use case 4: Resume training after saving a checkpoint manually
This example illustrates how we can save a LitGPT checkpoint from a previous training run that we can load and use later. Note that compared to using the Trainer checkpoint in the previous section, the model saved via this approach also contains the tokenizer and other relevant files. Hence, this approach does not require the original `"EleutherAI/pythia-160m"` model checkpoint directory.
```python
lit_model.llm.save("finetuned_checkpoint")
del lit_model
lit_model = LitLLM(checkpoint_dir="finetuned_checkpoint")
data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)
trainer = L.Trainer(
devices=1,
accelerator="cuda",
max_epochs=1,
accumulate_grad_batches=accumulate_grad_batches,
precision="bf16-true",
)
trainer.fit(lit_model, data)
lit_model.llm.model.to(lit_model.llm.preprocessor.device)
lit_model.llm.generate("hello world")
```
================================================
FILE: tutorials/quantize.md
================================================
# Quantize the model
This document provides different strategies for quantizing the various models available in LitGPT to reduce GPU memory usage, which is useful for running larger models on certain GPU hardware.
**All the examples below were run on an A100 40GB GPU with CUDA 12.1.**
> [!NOTE]
> Quantization also supports finetuning via [QLoRA](finetune_lora.md)
## Baseline
It's useful to start with a baseline to have a reference point for memory savings via the various quantization methods.
```bash
litgpt generate tiiuae/falcon-7b \
--precision 32-true \
--max_new_tokens 256
...
Time for inference 1: 6.93 sec total, 36.96 tokens/sec.
Memory used: 28.95 GB
```
First, using a lower precision compared to 32-bit float can result in two times reduced memory consumption. You can either try setting `--precision 16-true` for regular 16-bit precision or `--precision bf16-true` if your GPU supports brain-float 16-bit precision. ([This brief video](https://lightning.ai/courses/deep-learning-fundamentals/9.0-overview-techniques-for-speeding-up-model-training/unit-9.1-accelerated-model-training-via-mixed-precision-training/) explains the difference between regular 16-bit and bf16-bit precision.)
In short, when `--precision bf16-true` or `--precision 16-true` is used, the model weights will automatically be converted and consume less memory.
However, this might not be enough for large models or when using GPUs with limited memory.
```bash
litgpt generate tiiuae/falcon-7b \
--precision bf16-true \
--max_new_tokens 256
...
Time for inference 1: 5.37 sec total, 47.66 tokens/sec.
Memory used: 14.50 GB
```
To reduce the memory requirements further, LitGPT supports several quantization techniques, which are shown below.
> [!TIP]
> Most quantization examples below also use the `--precision bf16-true` setting explained above. If your GPU does not support `bfloat16`, you can change it to `--precision 16-true`.
## `bnb.nf4`
Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.
> [!IMPORTANT]
> `bitsandbytes` only supports `CUDA` devices and the `Linux` operating system.
> Windows users should use [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl).
Uses the normalized float 4 (nf4) data type. This is recommended over "fp4" based on the paper's experimental results and theoretical analysis.
```bash
pip install bitsandbytes
litgpt generate tiiuae/falcon-7b \
--quantize bnb.nf4 \
--precision bf16-true \
--max_new_tokens 256
...
Time for inference 1: 6.80 sec total, 37.62 tokens/sec
Memory used: 5.72 GB
```
## `bnb.nf4-dq`
Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.
"dq" stands for "Double Quantization" which reduces the average memory footprint by quantizing the quantization constants.
In average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).
```bash
pip install bitsandbytes
litgpt generate tiiuae/falcon-7b \
--quantize bnb.nf4-dq \
--precision bf16-true \
--max_new_tokens 256
...
Time for inference 1: 8.09 sec total, 30.87 tokens/sec
Memory used: 5.38 GB
```
## `bnb.fp4`
Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.
Uses pure FP4 quantization.
```bash
pip install bitsandbytes
litgpt generate tiiuae/falcon-7b \
--quantize bnb.fp4 \
--precision bf16-true \
--max_new_tokens 256
...
Time for inference 1: 6.92 sec total, 36.98 tokens/sec
Memory used: 5.72 GB
```
## `bnb.fp4-dq`
Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.
"dq" stands for "Double Quantization" which reduces the average memory footprint by quantizing the quantization constants.
In average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).
```bash
pip install bitsandbytes
litgpt generate tiiuae/falcon-7b \
--quantize bnb.fp4-dq \
--precision bf16-true \
--max_new_tokens 256
...
Time for inference 1: 10.02 sec total, 25.54 tokens/sec
Memory used: 5.38 GB
```
## `bnb.int8`
Enabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2110.02861) to learn more about how it works.
```bash
pip install bitsandbytes
litgpt generate tiiuae/falcon-7b \
--quantize bnb.int8 \
--precision 16-true \
--max_new_tokens 256
...
Time for inference 1: 20.22 sec total, 12.66 tokens/sec
Memory used: 8.70 GB
```
================================================
FILE: tutorials/resource-tables.md
================================================
# Resource Tables
- Last updated: 10/20/2023
- LitGPT version: commit 8641822
- Hardware: NVIDIA A100-SXM4-40GB
- OS: Ubuntu 22.04.3 LTS (x86_64)
- Nvidia driver version: 525.125.06
- Relevant libraries
- PyTorch 2.1.0+cu121
- Bitsandbytes 0.41.1
This document provides an overview and examples of hardware requirements when running models in LitGPT.
For additional tips on lowering the GPU memory footprint, please also see the [Dealing with out-of-memory (OOM) errors](oom.md) document.
All experiments were run using 16-bit brain floating point precision (`--precision bf16-true`). If your GPU does not support brain floating point precision, you can use regular 16-bit floating point precision (`--precision 16-true`).
All experiments were conducted using the Alpaca dataset with its default length. Note that due to different tokenizers being used by the different models, the number of tokens in the longest training example differs based on the model:
- phi1.5: 1044 tokens
- StableLM Alpha: 1034 tokens
- Llama 2: 1304 tokens
- Falcon 1079 tokens
Note that the number of tokens in the training set does not affect the supported context width (block size) of the models, which is as follows:
- phi1.5: 2048 tokens
- StableLM 3B Alpha: 4096 tokens
- Llama 2: 4048 tokens
- Falcon: 2048 tokens
- CodeLlama 13B: 16384 tokens
## Finetuning with LoRA on 1 GPU
The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_lora` command.
| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |
|-------|----------------|--------------|-----------------|----------------------|-------------|--------------------|
| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 4.82 GB | 1.62 min |
| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 3.78 GB | 1.77 min |
| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 3.72 GB | 1.87 min |
| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 6.76 GB | 1.65 min |
| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 10.68 GB | 1.70 min |
| | | | | | | |
| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 9.69 GB | 1.24 min |
| 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,097,152 | 6.35 GB | 1.82 min |
| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,097,152 | 6.19 GB | 1.87 min |
| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 12.10 GB | 1.33 min |
| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 16.92 GB | 1.50 min |
| | | | | | | |
| 7 B | Llama 2 | None | 1 | 4,194,304 | 21.30 GB | 2.36 min |
| 7 B | Llama 2 | bnb.nf4 | 1 | 4,194,304 | 14.14 GB | 3.68 min |
| 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,194,304 | 13.84 GB | 3.83 min |
| 7 B | Llama 2 | None | 2 | 4,194,304 | 29.07 GB | 2.52 min |
| 7 B | Llama 2 | None | 4 | 4,194,304 | OOM | - |
| | | | | | | |
| 13 B | Llama 2 | None | 1 | 6,553,600 | 38.12 GB | 3.19 min |
| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 23.14 GB | 6.38 min |
| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 22.55 GB | 6.55 min |
| 13 B | Llama 2 | None | 2 | 6,553,600 | OOM | - |
| 13 B | Llama 2 | None | 4 | 6,553,600 | OOM | - |
| | | | | | | |
| 40 B | Falcon | None | 1 | 12,042,240 | OOM | - |
| 40 B | Falcon | bnb.nf4 | 1 | 12,042,240 | OOM | - |
| 40 B | Falcon | bnb.nf4-dq | 1 | 12,042,240 | OOM | - |
## Finetuning with Adapter on 1 GPU
The following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_adapter` command.
| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |
|------|----------------|--------------|-----------------|----------------------|-------------|--------------------|
| 3 B | StableLM Alpha | None | 1 | 573,888 | 9.10 GB | 0.74 min |
| 3 B | StableLM Alpha | bnb.nf4 | 1 | 573,888 | 5.65 GB | 1.38 min |
| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 573,888 | 5.48 GB | 1.46 min |
| | | | | | | |
| 7 B | Llama 2 | None | 1 | 1,229,760 | 19.98 GB | 1.50 min |
| 7 B | Llama 2 | bnb.nf4 | 1 | 1,229,760 | 12.68 GB | 2.93 min |
| 7 B | Llama 2 | bnb.nf4-dq | 1 | 1,229,760 | 12.38 GB | 3.00 min |
The same config, but using the `litgpt finetune_adapter_v2` command.
| Size | Model | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |
|------|----------------|--------------|-----------------|----------------------|-------------|--------------------|
| 3 B | StableLM Alpha | None | 1 | 2,125,248 | 10.71 GB | 0.87 min |
| 3 B | StableLM Alpha | bnb.nf4 | 1 | 2,125,248 | 7.41 GB | 1.59 min |
| 3 B | StableLM Alpha | bnb.nf4-dq | 1 | 2,125,248 | 7.25 GB | 1.62 min |
| | | | | | | |
| 7 B | Llama 2 | None | 1 | 4,279,744 | 25.51 GB | 1.81 min |
| 7 B | Llama 2 | bnb.nf4 | 1 | 4,279,744 | 18.30 GB | 3.23 min |
| 7 B | Llama 2 | bnb.nf4-dq | 1 | 4,279,744 | 17.98 GB | 3.32 min |
## Finetuning with LoRA on Multiple GPUs
The following experiments were conducted on multiple A100 GPUs with a minibatch size of 128 using the `litgpt finetune_lora` command.
| Size | Model | Quantization | Microbatch size | Trainable parameters | GPU | Max GPU RAM | Time 1k iterations |
|-------|----------------|--------------|-----------------|----------------------|----------|-------------|--------------------|
| 1.3 B | phi-1.5 | None | 1 | 1,572,864 | 2 x A100 | 4.86 GB | 3.81 min |
| 1.3 B | phi-1.5 | bnb.nf4 | 1 | 1,572,864 | 2 x A100 | N/A | - |
| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 | 1,572,864 | 2 x A100 | N/A | - |
| 1.3 B | phi-1.5 | None | 2 | 1,572,864 | 2 x A100 | 5.05 GB | 3.63 min |
| 1.3 B | phi-1.5 | None | 4 | 1,572,864 | 2 x A100 | 5.88 GB | 3.64 min |
| | | | | | | | |
| 3 B | StableLM Alpha | None | 1 | 2,097,152 | 2 x A100 | 12.75 GB | 2.92 min |
| 3 B | StableLM Alpha | None | 2 | 2,097,152 | 2 x A100 | 12.94 GB | 3.06 min |
| 3 B | StableLM Alpha | None | 4 | 2,097,152 | 2 x A100 | 13.45 GB | 3.86 min |
| | | | | | | | - |
| 7 B | Llama 2 | None | 1 | 4,194,304 | 2 x A100 | 22.18 GB | 5.93 min |
| 7 B | Llama 2 | None | 2 | 4,194,304 | 2 x A100 | 22.47 GB | 6.48 min |
| 7 B | Llama 2 | None | 4 | 4,194,304 | 2 x A100 | 23.39 GB | 8.66 min |
| | | | | | | | |
| 13 B | Llama 2 | None | 1 | 6,553,600 | 2 x A100 | OOM | - |
| 13 B | Llama 2 | bnb.nf4 | 1 | 6,553,600 | 2 x A100 | N/A | - |
| 13 B | Llama 2 | bnb.nf4-dq | 1 | 6,553,600 | 2 x A100 | N/A | - |
| | | | | | | | |
| 13 B | Llama 2 | None | 1 | 6,553,600 | 4 x A100 | 35.57 GB | 10.25 min |
| 40 B | Falcon | None | 1 | 12,042,240 | 4 x A100 | OOM | - |
## Single-GPU Inference
| Size | Model | Quantization | GPU | Max GPU RAM | Token/sec |
|-------|----------------|--------------|----------|-------------------------------------------|-----------|
| 1.3 B | phi-1.5 | None | 1 x A100 | 2.86 GB | 42.56 |
| 1.3 B | phi-1.5 | bnb.nf4 | 1 x A100 | 1.39 GB | 22.89 |
| 1.3 B | phi-1.5 | bnb.nf4-dq | 1 x A100 | 1.33 GB | 22.75 |
| | | | | | |
| 3 B | StableLM Alpha | None | 1 x A100 | 7.30 GB | 49.01 |
| 3 B | StableLM Alpha | bnb.nf4 | 1 x A100 | 3.20 GB | 29.04 |
| 3 B | StableLM Alpha | bnb.nf4-dq | 1 x A100 | 3.04 GB | 27.15 |
| | | | | | |
| 7 B | Llama 2 | None | 1 x A100 | 13.52 GB | 30.97 |
| 7 B | Llama 2 | bnb.nf4 | 1 x A100 | 4.57 GB | 19.98 |
| 7 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 4.26 GB | 17.3 |
| | | | | | |
| 13 B | Llama 2 | None | 1 x A100 | 26.21 GB | 24.82 |
| 13 B | Llama 2 | bnb.nf4 | 1 x A100 | 8.32 GB | 16.73 |
| 13 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 7.72 GB | 14.43 |
| | | | | | |
| 34 B | CodeLlama | None | 1 x A100 | OOM | - |
| 34 B | CodeLlama | bnb.nf4 | 1 x A100 | 20.52 GB | 14.32 |
| 34 B | CodeLlama | bnb.nf4-dq | 1 x A100 | 18.95 GB | 12.37 |
| | | | | | |
| 40 B | Falcon | None | 1 x A100 | OOM | - |
| 40 B | Falcon | bnb.nf4 | 1 x A100 | 26.55 GB | 13.25 |
| 40 B | Falcon | bnb.nf4-dq | 1 x A100 | 24.63 GB | 11.64 |
| | | | | | |
| 70 B | Llama 2 | None | 1 x A100 | OOM | - |
| 70 B | Llama 2 | bnb.nf4 | 1 x A100 | CUDA error: CUBLAS_STATUS_NOT_INITIALIZED | - |
| 70 B | Llama 2 | bnb.nf4-dq | 1 x A100 | 37.21 GB | 7.97 |